From a8c5c54532d5e7713c8724945e3aa8c61b235abf Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 18 Aug 2025 15:22:23 +0800 Subject: [PATCH 01/94] upgrade activation_func to transformers v4.54 --- mindone/transformers/activations.py | 61 +++++++++++++---------------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/mindone/transformers/activations.py b/mindone/transformers/activations.py index 8ff354ab6e..691fb55947 100644 --- a/mindone/transformers/activations.py +++ b/mindone/transformers/activations.py @@ -17,34 +17,33 @@ import math from collections import OrderedDict -from functools import partial import mindspore as ms -from mindspore import Tensor, nn, ops +from mindspore import Tensor, nn, mint class PytorchGELUTanh(nn.Cell): """ A fast C implementation of the tanh approximation of the GeLU activation function. See - https://arxiv.org/abs/1606.08415. + https://huggingface.co/papers/1606.08415. This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical match due to rounding errors. """ def construct(self, input: Tensor) -> Tensor: - return ops.gelu(input, approximate="tanh") + return mint.nn.functional.gelu(input, approximate="tanh") class NewGELUActivation(nn.Cell): """ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415 """ def construct(self, input: Tensor) -> Tensor: return ( - 0.5 * input * (1.0 + ops.tanh(ops.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0)))) + 0.5 * input * (1.0 + mint.tanh(mint.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * mint.pow(input, 3.0)))) ).to(input.dtype) @@ -52,8 +51,8 @@ class GELUActivation(nn.Cell): """ Original Implementation of the GELU activation function in Google BERT repo when initially created. For information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + - ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))) This is now written in C in nn.functional - Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + mint.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * mint.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415 """ def __init__(self, use_gelu_python: bool = False): @@ -61,10 +60,10 @@ def __init__(self, use_gelu_python: bool = False): if use_gelu_python: self.act = self._gelu_python else: - self.act = ops.gelu + self.act = mint.nn.functional.gelu def _gelu_python(self, input: Tensor) -> Tensor: - return input * 0.5 * (1.0 + ops.erf(input / math.sqrt(2.0))) + return input * 0.5 * (1.0 + mint.erf(input / math.sqrt(2.0))) def construct(self, input: Tensor) -> Tensor: return self.act(input) @@ -76,7 +75,7 @@ class FastGELUActivation(nn.Cell): """ def construct(self, input: Tensor) -> Tensor: - return 0.5 * input * (1.0 + ops.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + return 0.5 * input * (1.0 + mint.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) class QuickGELUActivation(nn.Cell): @@ -84,25 +83,21 @@ class QuickGELUActivation(nn.Cell): Applies GELU approximation that is fast but somewhat inaccurate. 
See: https://github.com/hendrycks/GELUs """ - def __init__(self): - super(QuickGELUActivation, self).__init__() - self.sigmoid = nn.Sigmoid() - - def construct(self, input): - return input * self.sigmoid(1.702 * input) + def construct(self, input: Tensor) -> Tensor: + return input * mint.sigmoid(1.702 * input) class ClippedGELUActivation(nn.Cell): """ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to - https://arxiv.org/abs/2004.09602. + https://huggingface.co/papers/2004.09602. Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + - ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))). See https://huggingface.co/papers/1606.08415 """ def __init__(self, min: float, max: float): @@ -115,7 +110,7 @@ def __init__(self, min: float, max: float): self.gelu = get_activation("gelu") def construct(self, x: Tensor) -> Tensor: - return ops.clip(self.gelu(x), self.min, self.max) + return mint.clip(self.gelu(x), self.min, self.max) class AccurateGELUActivation(nn.Cell): @@ -131,7 +126,7 @@ def __init__(self): self.precomputed_constant = math.sqrt(2 / math.pi) def construct(self, input: Tensor) -> Tensor: - return 0.5 * input * (1 + ops.tanh(self.precomputed_constant * (input + 0.044715 * ops.pow(input, 3)))) + return 0.5 * input * (1 + mint.tanh(self.precomputed_constant * (input + 0.044715 * mint.pow(input, 3)))) class SiLUActivationFP32(nn.Cell): @@ -149,12 +144,12 @@ def construct(self, x): class MishActivation(nn.Cell): """ - See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also visit the official repository for the paper: https://github.com/digantamisra98/Mish """ def construct(self, input: Tensor) -> Tensor: - return ops.mish(input) + return mint.nn.functional.mish(input) class LinearActivation(nn.Cell): @@ -169,24 +164,24 @@ def construct(self, input: Tensor) -> Tensor: class LaplaceActivation(nn.Cell): """ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. 
See - https://arxiv.org/abs/2209.10655 + https://huggingface.co/papers/2209.10655 Inspired by squared relu, but with bounded range and gradient for better stability """ def construct(self, input, mu=0.707107, sigma=0.282095): input = (input - mu).div(sigma * math.sqrt(2.0)) - return 0.5 * (1.0 + ops.erf(input)) + return 0.5 * (1.0 + mint.erf(input)) class ReLUSquaredActivation(nn.Cell): """ - Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2 """ def construct(self, input): - relu_applied = ops.relu(input) - squared = ops.square(relu_applied) + relu_applied = mint.nn.functional.relu(input) + squared = mint.square(relu_applied) return squared @@ -198,7 +193,7 @@ def __getitem__(self, key): ACT2CLS = { - "gelu": partial(nn.GELU, approximate=False), + "gelu": mint.nn.GELU, "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), "gelu_fast": FastGELUActivation, "gelu_new": NewGELUActivation, @@ -209,13 +204,13 @@ def __getitem__(self, key): "linear": LinearActivation, "mish": MishActivation, "quick_gelu": QuickGELUActivation, - "relu": nn.ReLU, + "relu": mint.nn.ReLU, "relu2": ReLUSquaredActivation, - "relu6": nn.ReLU6, - "sigmoid": nn.Sigmoid, + "relu6": mint.nn.ReLU6, + "sigmoid": mint.nn.Sigmoid, "silu": SiLUActivationFP32, "swish": SiLUActivationFP32, - "tanh": nn.Tanh, + "tanh": mint.nn.Tanh, } ACT2FN = ClassInstantier(ACT2CLS) From c05a70747cdc4b1fad1999cc4be4c856c7c22643 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Tue, 19 Aug 2025 10:07:28 +0800 Subject: [PATCH 02/94] feat(transformers): upgrade attn_mask/rope to 4.54 --- .../transformers/modeling_attn_mask_utils.py | 5 + mindone/transformers/modeling_rope_utils.py | 135 ++++++++---------- 2 files changed, 62 insertions(+), 78 deletions(-) diff --git a/mindone/transformers/modeling_attn_mask_utils.py b/mindone/transformers/modeling_attn_mask_utils.py index f2bf12136b..bb6a21cb88 100644 --- a/mindone/transformers/modeling_attn_mask_utils.py +++ b/mindone/transformers/modeling_attn_mask_utils.py @@ -14,6 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general +`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now, +and will be removed in the future. +""" from dataclasses import dataclass from typing import List, Optional, Tuple, Union diff --git a/mindone/transformers/modeling_rope_utils.py b/mindone/transformers/modeling_rope_utils.py index 8eb1ae9875..42587d6b05 100644 --- a/mindone/transformers/modeling_rope_utils.py +++ b/mindone/transformers/modeling_rope_utils.py @@ -83,7 +83,7 @@ def wrapper(self, x, position_ids): def _compute_default_rope_parameters( - config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None, **rope_kwargs + config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None ) -> tuple[Tensor, float]: """ Computes the inverse frequencies according to the original RoPE implementation @@ -92,25 +92,14 @@ def _compute_default_rope_parameters( The model configuration. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. 
- rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -122,7 +111,6 @@ def _compute_default_rope_parameters( def _compute_linear_scaling_rope_parameters( config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None, - **rope_kwargs, ) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev @@ -131,24 +119,14 @@ def _compute_linear_scaling_rope_parameters( The model configuration. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - factor = rope_kwargs["factor"] - elif config is not None: - factor = config.rope_scaling["factor"] + factor = config.rope_scaling["factor"] # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs) + inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len) # Then applies linear scaling to the frequencies. # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so @@ -160,7 +138,6 @@ def _compute_linear_scaling_rope_parameters( def _compute_dynamic_ntk_parameters( config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None, - **rope_kwargs, ) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla @@ -169,35 +146,30 @@ def _compute_dynamic_ntk_parameters( The model configuration. seq_len (`int`, *optional*): The current sequence length, used to update the dynamic RoPE at inference time. 
- rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - max_position_embeddings = rope_kwargs["max_position_embeddings"] - factor = rope_kwargs["factor"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - max_position_embeddings = config.max_position_embeddings - factor = config.rope_scaling["factor"] + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] attention_factor = 1.0 # Unused in this type of RoPE # seq_len: default to max_position_embeddings, e.g. at init time - seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings + if seq_len is None: + seq_len = max_position_embeddings + elif isinstance(seq_len, ms.Tensor): + seq_len = mint.maximum( + seq_len, + ms.tensor(max_position_embeddings, dtype=seq_len.dtype), + ) + else: + seq_len = max(seq_len, max_position_embeddings) # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) @@ -206,39 +178,49 @@ def _compute_dynamic_ntk_parameters( def _compute_yarn_parameters( - config: PretrainedConfig, seq_len: Optional[int] = None, **rope_kwargs + config: PretrainedConfig, seq_len: Optional[int] = None ) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://arxiv.org/abs/2309.00071) + [original paper](https://huggingface.co/papers/2309.00071) Args: config ([`~transformers.PretrainedConfig`]): The model configuration. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ - # No need to keep BC with yarn, unreleased when this new pattern was created. 
- if len(rope_kwargs) > 0: - raise ValueError( - f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" - ) - base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) - max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] + attention_factor = config.rope_scaling.get("attention_factor") + mscale = config.rope_scaling.get("mscale") + mscale_all_dim = config.rope_scaling.get("mscale_all_dim") + + # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if "original_max_position_embeddings" in config.rope_scaling: + original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 # Sets the attention factor as suggested in the paper - attention_factor = config.rope_scaling.get("attention_factor") if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 + if mscale and mscale_all_dim: + attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) # Optional config options # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) @@ -270,7 +252,7 @@ def linear_ramp_factor(min, max, dim): inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) - low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) # Get n-dimensional rotational scaling corrected for extrapolation inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float() @@ -283,7 +265,7 @@ def linear_ramp_factor(min, max, dim): def _compute_longrope_parameters( - config: PretrainedConfig, seq_len: Optional[int] = None, **rope_kwargs + config: PretrainedConfig, seq_len: Optional[int] = None ) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies with LongRoPE scaling. Please refer to the @@ -293,19 +275,11 @@ def _compute_longrope_parameters( The model configuration. seq_len (`int`, *optional*): The current sequence length. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - # No need to keep BC with longrope, unreleased when this new pattern was created. 
- if len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " - f"{rope_kwargs}" - ) base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 @@ -344,7 +318,7 @@ def _compute_longrope_parameters( def _compute_llama3_parameters( - config: PretrainedConfig, seq_len: Optional[int] = None, **rope_kwargs + config: PretrainedConfig, seq_len: Optional[int] = None ) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies for llama 3.1. @@ -354,14 +328,12 @@ def _compute_llama3_parameters( The model configuration. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs) + inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len) factor = config.rope_scaling["factor"] # `8` in the original implementation low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation @@ -464,7 +436,14 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} - optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + optional_keys = { + "attention_factor", + "beta_fast", + "beta_slow", + "original_max_position_embeddings", + "mscale", + "mscale_all_dim", + } received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) From 84b3eceb6e8df2b1854b742bab02320ec480f4a1 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Tue, 19 Aug 2025 10:24:04 +0800 Subject: [PATCH 03/94] feat(transformers): upgrade modeling_layers to 4.54 --- mindone/transformers/modeling_layers.py | 262 ++++++++++++++++++++++++ mindone/transformers/utils/generic.py | 21 ++ 2 files changed, 283 insertions(+) create mode 100644 mindone/transformers/modeling_layers.py diff --git a/mindone/transformers/modeling_layers.py b/mindone/transformers/modeling_layers.py new file mode 100644 index 0000000000..1f23264b2b --- /dev/null +++ b/mindone/transformers/modeling_layers.py @@ -0,0 +1,262 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
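The new `modeling_layers.py` below introduces `GradientCheckpointingLayer` together with the generic task heads (`GenericForSequenceClassification`, `GenericForQuestionAnswering`, `GenericForTokenClassification`). A minimal sketch of how a concrete model is expected to pick one of them up — the `MyModel*` names are placeholders for a real model family and are not part of this patch:

```python
# Sketch only: `MyModelPreTrainedModel` / `MyModelConfig` are illustrative placeholders.
# The mixin resolves its backbone through `base_model_prefix` and `AutoModel.from_config(config)`,
# so the concrete class usually reduces to a bare composition.
class MyModelForSequenceClassification(GenericForSequenceClassification, MyModelPreTrainedModel):
    pass
```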
+from abc import ABC +from typing import Optional + +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint + +from .cache_utils import Cache +from .modeling_outputs import ( + BaseModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from .models.auto import AutoModel +from .processing_utils import Unpack +from .utils import TransformersKwargs, logging +from transformers.utils import auto_docstring, can_return_tuple + +logger = logging.get_logger(__name__) + + +class GradientCheckpointingLayer(nn.Cell): + """Base class for layers with gradient checkpointing. + + This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled + (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is + enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`. + + Important: + + When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states) + must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients. + + Example: + + ```python + >>> # Correct - hidden_states passed as positional arg + >>> out = self.layer(hidden_states, attention_mask=attention_mask) + + >>> # Incorrect - hidden_states passed as keyword arg + >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask) + ``` + """ + + gradient_checkpointing = False + + def __call__(self, *args, **kwargs): + if self.gradient_checkpointing and self.training: + raise NotImplementedError + return super().__call__(*args, **kwargs) + + +@auto_docstring +class GenericForSequenceClassification(ABC): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class + setattr(self, self.base_model_prefix, AutoModel.from_config(config)) + self.score = mint.nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutputWithPast: + transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal 
to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(ms.int32) + token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[mint.arange(batch_size), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class GenericForQuestionAnswering(ABC): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class + setattr(self, self.base_model_prefix, AutoModel.from_config(config)) + self.qa_outputs = mint.nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return getattr(self, self.base_model_prefix).embed_tokens + + def set_input_embeddings(self, value): + getattr(self, self.base_model_prefix).embed_tokens = value + + @can_return_tuple + @auto_docstring + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + start_positions: Optional[ms.Tensor] = None, + end_positions: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> QuestionAnsweringModelOutput: + outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class GenericForTokenClassification(ABC): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class + setattr(self, self.base_model_prefix, AutoModel.from_config(config)) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = 
nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs, + ) -> TokenClassifierOutput: + outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index be560c7988..d571e4a53e 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -445,6 +445,27 @@ def mindspore_float(x): return x.to(ms.float32) if isinstance(x, ms.Tensor) else int(x) +class TransformersKwargs(TypedDict, total=False): + """ + Keyword arguments to be passed to the loss function + + Attributes: + num_items_in_batch (`Optional[ms.Tensor]`, *optional*): + Number of items in the batch. It is recommended to pass it when + you are doing gradient accumulation. + output_hidden_states (`Optional[bool]`, *optional*): + Most of the models support outputing all hidden states computed during the forward pass. + output_attentions (`Optional[bool]`, *optional*): + Turn this on to return the intermediary attention scores. + output_router_logits (`Optional[bool]`, *optional*): + For MoE models, this allows returning the router logits to compute the loss. + """ + + num_items_in_batch: Optional["ms.Tensor"] + output_hidden_states: Optional[bool] + output_attentions: Optional[bool] + output_router_logits: Optional[bool] + def filter_out_non_signature_kwargs(extra: Optional[list] = None): """ From 6b94c2d2dd9a0cf1aaf119c8c593ff629ce90929 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:16:25 +0800 Subject: [PATCH 04/94] feat(transformers): upgrade cache_utils to 4.54 --- mindone/transformers/__init__.py | 3 +- mindone/transformers/cache_utils.py | 1745 ++++++++++++++++++--------- 2 files changed, 1201 insertions(+), 547 deletions(-) diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index e4b7d051d0..546edbc7f6 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -21,10 +21,11 @@ # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names # in the namespace without actually importing anything (and especially none of the backends). 
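A note on the `TransformersKwargs` TypedDict added above: it is purely a typing aid for the `**kwargs: Unpack[TransformersKwargs]` signatures used by the generic heads, so callers simply pass the usual flags as keyword arguments. A hedged, caller-side sketch (the `model` and input tensors are illustrative):

```python
# Illustrative only: `model` is any head whose construct accepts **kwargs: Unpack[TransformersKwargs];
# the flags map one-to-one to the TypedDict fields and are forwarded to the backbone.
outputs = model(
    input_ids,
    attention_mask=attention_mask,
    output_hidden_states=True,  # -> outputs.hidden_states
    output_attentions=True,     # -> outputs.attentions
)
```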
-__version__ = "4.50.0" +__version__ = "4.54.0" import transformers from packaging import version +from .cache_utils import * # Feature Extractor from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .image_processing_base import ImageProcessingMixin diff --git a/mindone/transformers/cache_utils.py b/mindone/transformers/cache_utils.py index 365a01c4f4..b95ed93cb3 100644 --- a/mindone/transformers/cache_utils.py +++ b/mindone/transformers/cache_utils.py @@ -4,9 +4,13 @@ Cache utils. """ import copy +import functools +import inspect import json import os -from typing import Any, Dict, List, Optional, Tuple, Union +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np from transformers.configuration_utils import PretrainedConfig @@ -145,312 +149,751 @@ def reset(past_key_values): return past_key_values +class CacheLayerMixin(ABC): + """Base, abstract class for a single layer's cache.""" -class Cache(nn.Cell): + is_compileable = False + + def __init__(self): + self.keys, self.values = None, None + + @abstractmethod + def update( + self, + key_states: ms.Tensor, + value_states: ms.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: ... + + @abstractmethod + def get_seq_length(self, cache_position=None) -> int: ... + + @abstractmethod + def get_max_cache_shape(self) -> int: ... + + @abstractmethod + def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: ... + + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + self.keys.zero_() + self.values.zero_() + + def reorder_cache(self, beam_idx: ms.Tensor) -> tuple[ms.Tensor, ms.Tensor]: + """Reorders this layer's cache for beam search.""" + if self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx) + if self.values.numel(): + self.values = self.values.index_select(0, beam_idx) + +class DynamicLayer(CacheLayerMixin): """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. """ - is_compileable = False + is_sliding = False def update( self, key_states: ms.Tensor, value_states: ms.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ms.Tensor, ms.Tensor]: + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states`. Parameters: key_states (`ms.Tensor`): The new key states to cache. value_states (`ms.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. Return: A tuple containing the updated key and value states. 
""" - raise NotImplementedError("Make sure to implement `update` in a subclass.") + if self.keys is None: + self.keys = key_states + self.values = value_states + else: + self.keys = mint.cat([self.keys, key_states], dim=-2) + self.values = mint.cat([self.values, value_states], dim=-2) + return self.keys, self.values - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states, if there is any.""" - raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_length() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length + def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of the cached states.""" + if self.keys is None or self.keys.numel() == 0: + return 0 + return self.keys.shape[-2] - def reorder_cache(self, beam_idx: ms.Tensor): + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return -1 + + def reorder_cache(self, beam_idx: ms.Tensor) -> None: """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] != []: - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - if self.value_cache[layer_idx] != []: - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx) + self.values = self.values.index_select(0, beam_idx) - @property - def seen_tokens(self): - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None + def crop(self, max_length: int) -> None: + """ + Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. 
+ """ + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + if self.get_seq_length() <= max_length: + return -class StaticCache(Cache): + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: ms.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] + + def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the mask""" + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @classmethod + def from_tensors(cls, keys: ms.Tensor, values: ms.Tensor) -> "DynamicLayer": + """ + Build a `DynamicLayer` instance from pre-existing key/value tensors. + + Args: + keys (`ms.Tensor`): + Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + values (`ms.Tensor`): + Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + + Returns: + `DynamicLayer`: The newly constructed layer whose internal cache directly references + the supplied tensors. + """ + layer = cls() + layer.keys = keys + layer.values = values + return layer + + +class StaticLayer(CacheLayerMixin): """ - Static Cache class to be used with `static shape`. + A static cache layer that stores the Key and Value states as static tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + It allocates its full backing tensors up-front and mutates them in-place. Built for `mindspore.jit` support. - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - dtype (*optional*, defaults to `ms.float32`): - The default `dtype` to use when initializing the layer. + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. """ is_compileable = True + is_sliding = False - def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, dtype=None) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) + def __init__( + self, + max_cache_len: int, + batch_size: int, + num_heads: int, + head_dim: int, + dtype: ms.Type = ms.float32, + sliding_window: Optional[int] = None, + ): + """ + Args: + max_cache_len (`int`): + Maximum number of tokens that can be stored, used for tensor preallocation. + batch_size (`int`): + Maximum batch size the cache is pre-allocated for. 
+ num_heads (`int`): + Number of attention heads. + head_dim (`int`): + Per-head hidden dimension. + dtype (`ms.Type`, defaults to `ms.float32`): + Data type of the cache tensors. + + Notes: + Static layers allocate their full backing tensors up-front and mutate them + in-place. See the documentation of `Cache` for shared helper methods that + operate uniformly across all layer types. + """ + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size + self.num_heads = num_heads + self.head_dim = head_dim + self.dtype = dtype - self.dtype = dtype if dtype is not None else ms.float32 - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads + self.keys = mint.zeros( + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, ) + self.values = mint.zeros( + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + # fixme there is no implementation for torch._dynamo.mark_static_address - key_cache: List[ms.Parameter] = [] - value_cache: List[ms.Parameter] = [] - cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - for _layer_index in range(config.num_hidden_layers): - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - new_layer_key_cache = ms.Parameter( - ms.Tensor(np.zeros(cache_shape), dtype=self.dtype), - name=f"key_cache_{_layer_index}", - requires_grad=False, - ) - new_layer_value_cache = ms.Parameter( - ms.Tensor(np.zeros(cache_shape), dtype=self.dtype), - name=f"value_cache_{_layer_index}", - requires_grad=False, - ) - key_cache.append(new_layer_key_cache) - value_cache.append(new_layer_value_cache) - - self.key_cache = ms.ParameterTuple(key_cache) - self.value_cache = ms.ParameterTuple(value_cache) + def get_max_cache_shape(self) -> int: + """Return the maximum cache shape of the cache""" + return self.max_cache_len def update( self, key_states: ms.Tensor, value_states: ms.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ms.Tensor, ms.Tensor]: + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Update the static cache tensors in place. - Parameters: - key_states (`ms.Tensor`): - The new key states to cache. - value_states (`ms.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. + Args: + key_states (`ms.Tensor`): The new key states to cache. + value_states (`ms.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - Return: - A tuple containing the updated key and value states. + Returns: + tuple[`ms.Tensor`, `ms.Tensor`]: The updated key and value states. 
""" - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) + # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. + self.keys.copy_(key_states) + self.values.copy_(value_states) else: - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" + # Generation phase. Update specific positions. + # Use index_copy_ for in-place update (compile-friendly). + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values + + def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of the cached states.""" + if cache_position is not None: + return int(cache_position[-1] + 1) # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - return (self.key_cache[layer_idx][0, 0].any(axis=-1)).sum() + seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0 + return seq_length - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" - return self.max_cache_len + def reorder_cache(self, beam_idx: ms.Tensor) -> None: + """Reorders the cache for beam search, given the selected beam indices.""" + self.keys = self.keys.index_select(0, beam_idx) + self.values = self.values.index_select(0, beam_idx) - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - ops.assign(self.key_cache[layer_idx], ms.Tensor(0.0)) - ops.assign(self.value_cache[layer_idx], ms.Tensor(0.0)) + def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" + kv_offset = 0 + kv_length = self.max_cache_len + return kv_length, kv_offset -class CacheConfig: +class SlidingWindowLayer(StaticLayer): """ - Base class for cache configs + A static cache layer that implements sliding window attention caching. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. """ - cache_implementation: None + is_sliding = True - @classmethod - def from_dict(cls, config_dict, **kwargs): + def __init__(self, sliding_window, *args, **kwargs): """ - Constructs a CacheConfig instance from a dictionary of parameters. Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. 
+ sliding_window (`int`): + Effective window size: number of tokens that are kept on each update call. """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config + max_cache_len = kwargs.pop("max_cache_len", None) + max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window + super().__init__(*args, max_cache_len=max_cache_len, *args, **kwargs) - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: Union[str, os.PathLike]): + def update( + self, + key_states: ms.Tensor, + value_states: ms.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ - Save this instance to a JSON file. + Update the sliding window cache tensors in place. Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. + key_states (`ms.Tensor`): The new key states to cache. + value_states (`ms.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`ms.Tensor`, `ms.Tensor`]: The updated key and value states. """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_position is None: + raise ValueError("`cache_position` must be provided for SlidingWindowLayer.") - writer.write(json_string) + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> Dict[str, Any]: + # Handle prefill phase when prompt length > sliding_window_size. + # Note that we store cropped key/value states in the cache but return the full key/value states. 
+ if cache_position.shape[0] > self.max_cache_len: + new_k = key_states[:, :, -self.max_cache_len :, :] + new_v = value_states[:, :, -self.max_cache_len :, :] + self.keys.copy_(new_k) + self.values.copy_(new_v) + return key_states, value_states + + # Sliding window logic for generation phase or prefill < window + slicing = mint.arange(self.max_cache_len) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + to_shift = current_seq_len > self.max_cache_len + indices = (slicing + to_shift.sum()) % self.max_cache_len + + k_out_shifted = self.keys[:, :, indices] + v_out_shifted = self.values[:, :, indices] + + # Clamp cache_position to determine the *target index* within the shifted cache view + update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) + + try: + k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) + v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) + except NotImplementedError: + # Fallback for MPS: clone and modify the clone + k_out_updated = k_out_shifted.clone() + v_out_updated = v_out_shifted.clone() + k_out_updated[:, :, update_position] = key_states + v_out_updated[:, :, update_position] = value_states + + self.keys.copy_(k_out_updated) + self.values.copy_(v_out_updated) + return self.keys, self.values + + def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + kv_offset = mint.clamp(first_cache_position - self.max_cache_len + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + kv_length = max(query_length, self.max_cache_len) + return kv_length, kv_offset + + +class ChunkedSlidingLayer(SlidingWindowLayer): + """ + An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4. + + See `SlidingWindowLayer` for details on common methods that are implemented by all cache layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cumulative_length = 0 + + def update( + self, + key_states: ms.Tensor, + value_states: ms.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_position is None: + raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.") + + cumulative_length = self.cumulative_length + self.cumulative_length += key_states.shape[-2] + is_full = cumulative_length >= self.max_cache_len + + if is_full: + full_key_states = mint.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_value_states = mint.cat((self.values[:, :, 1:, :], value_states), dim=-2) + # Fast decoding path -> here as the effective size is still sliding window, it is extremely important + # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address + # in memory (the values are the same as the full states, but not the address!!) 
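+            # A single-token decode writes the shifted window back into the preallocated buffers below and
+            # returns them (stable memory address); multi-token chunks instead fall through to the final
+            # copy at the end of this method and return the full, uncropped states.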
+ if key_states.shape[-2] == 1: + self.keys.copy_(full_key_states) + self.values.copy_(full_value_states) + return self.keys, self.values + elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = mint.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = mint.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2) + else: + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values + + self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) + return full_key_states, full_value_states + + def reset(self) -> None: + super().reset() + self.cumulative_length = 0 + + def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + sliding_window = self.max_cache_len + + kv_offset = mint.clamp(first_cache_position - sliding_window + 1, min=0) + # This is the true general case for any Cache using local attention (sliding or chunked) + if first_cache_position >= sliding_window: + # Here the Cache is already full + kv_length = sliding_window + query_length - 1 + elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window: + # Here the Cache becomes full with the new input + kv_length = first_cache_position + query_length + else: + # Here the Cache is still smaller than the local size, but we return the local size as it's static + kv_length = sliding_window + return kv_length, kv_offset + + +class CacheProcessor: + """ + Base class for cache processors. It defines a pre-update and post-update methods that are called before and after the cache update. + This class should be subclassed. + """ + + def __init__(self, cache: "Cache", **kwargs) -> None: """ - Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + Initialize the processor and perform compatibility checks with the cache. + + Args: + cache (`Cache`): The cache instance this processor will be applied to. + **kwargs: Additional arguments that may be needed for initialization. """ - return copy.deepcopy(self.__dict__) + raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value + def pre_update( + self, + cache: "Cache", + key_states: ms.Tensor, + value_states: ms.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + Function called before the cache update. Can modify the key/value states. - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" + Args: + cache (`Cache`): The cache instance. 
+ key_states (`ms.Tensor`): The new key states to cache. + value_states (`ms.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. Returns: - str: JSON formatted string representing the configuration instance. + The modified key and value states. """ - return json.dumps(self.__dict__, indent=2) + "\n" + return key_states, value_states - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): + def post_update( + self, + cache: "Cache", + key_tensors: ms.Tensor, + value_tensors: ms.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. + Function called after the cache update. Can process the cached data. Args: - kwargs (`Dict[str, Any]`): - Dictionary of attributes to tentatively update this class. + cache (`Cache`): The cache instance. + key_states (`ms.Tensor`): The key states that were cached. + value_states (`ms.Tensor`): The value states that were cached. + layer_idx (`int`): The index of the layer that was updated. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: - `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + The final key and value states to return to the model. """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) + return key_tensors, value_tensors - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs +class OffloadedCacheProcessor(CacheProcessor): + """ + A cache processor that offloads cache tensors to conserve accelerator memory. -class DynamicCache(Cache): + This processor manages moving cache tensors between accelerator and CPU memory, + using asynchronous prefetching to minimize performance impact. Works with both + dynamic and static layers. """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. + def __init__(self, cache: "Cache", **kwargs): + raise NotImplementedError + + +class QuantizedCacheProcessor(CacheProcessor): """ + A cache processor that applies quantization to cache tensors to reduce memory usage. - def __init__(self, num_hidden_layers: Optional[int] = None) -> None: - # in hf transformers there is no `num_hidden_layers` but `_distributed_cache_data` - # it was originally added for compatibility with `torch.distributed` (DDP). See #36121 - # in mindspore there is no DDP, so we keep `num_hidden_layers` - super().__init__() - if num_hidden_layers is None: - self.key_cache: List[ms.Tensor] = [] - self.value_cache: List[ms.Tensor] = [] + This processor quantizes cache tensors after they are stored, maintaining a residual + length in original precision and quantizing older tokens. 
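+
+    Note: in this MindSpore port the quantized cache processors are currently interface placeholders, and
+    instantiating them raises `NotImplementedError`.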
+ """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: ms.Type = ms.float16, + ): + """ + Parameters: + backend (`str`, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`int`, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`int`, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`int`, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. + compute_dtype (`ms.Type`, defaults to `ms.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + """ + raise NotImplementedError + + +class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `quanto` as a backend to perform quantization. + Current implementation supports `int2` and `int4` dtypes only. + """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: ms.Type = ms.float16, + ) -> None: + """Initialize the quanto quantization processor.""" + super().__init__( + cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype + ) + + raise NotImplementedError + + +class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `HQQ` as a backend to perform quantization. + Current implementation supports `int2`, `int4`, `int8` dtypes. + """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: ms.Type = ms.float16, + ) -> None: + """Initialize the HQQ quantization processor.""" + super().__init__( + cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype + ) + raise NotImplementedError + + +def apply_processors( + fn: Callable[..., tuple[ms.Tensor, ms.Tensor]], +) -> Callable[..., tuple[ms.Tensor, ms.Tensor]]: + @functools.wraps(fn) + def _wrapped_update( + self, + key_states: ms.Tensor, + value_states: ms.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + Wrapper around the update method to apply cache processors. 
+ """ + if self.cache_processor is not None: + key_states, value_states = self.cache_processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) + + key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs) + + if self.cache_processor is not None: + key_tensors, value_tensors = self.cache_processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + + return key_tensors, value_tensors + + return _wrapped_update + + +class KeyValuesWrapper: + """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache. + This allows for BC access and writing, e.g., cache.key_cache[idx] = ... + Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0""" + + def __init__(self, layers, cache_type="keys"): + self.layers = layers + self.cache_type = cache_type + + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, self.cache_type) for layer in self.layers[idx]] + return getattr(self.layers[idx], self.cache_type) + + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, self.cache_type, val) else: - self.key_cache: List[ms.Tensor] = [[] for _ in range(num_hidden_layers)] - self.value_cache: List[ms.Tensor] = [[] for _ in range(num_hidden_layers)] - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + setattr(self.layers[idx], self.cache_type, value) + + def __len__(self): + return len(self.layers) + + def __iter__(self): + for layer in self.layers: + yield getattr(layer, self.cache_type) + + def __bool__(self): + return bool(self.layers) + - def __getitem__(self, layer_idx: int) -> List[Tuple[ms.Tensor]]: +class Cache: + """ + Base container for per-layer key/value caches. + + A `Cache` behaves like a list of `CacheLayerMixin` objects, one per model layer. + Sub-classes such as `DynamicCache`, `StaticCache`, or `SlidingWindowCache` + simply pre-select which `CacheLayerMixin` class to use and may attach a + `CacheProcessor` (off-loading, quantization). + + Example + ------- + ```python + from mindone.transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache + + model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + inputs = tok("Hello", return_tensors="np") + for key in inputs.keys(): + inputs[key] = ms.tensor(inputs[key]) + + cache = DynamicCache() + outputs = model(**inputs, past_key_values=cache, use_cache=True) + ``` + + Parameters: + layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`): + A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is + provided, then it is used for all layers. + config (`PretrainedConfig`, *optional*): + Model configuration used to infer number of layers, head sizes, default + device/dtype, etc. + cache_processor (`CacheProcessor` or `str`, *optional*): + Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") + or a CacheProcessor class. + max_batch_size (`int`, *optional*): Maximum batch size for static caches. + max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are + clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`. + dtype (`ms.Type`, *optional*): Data type for cache tensors. 
+ tp_size (`int`, *optional*): Tensor parallel size to adjust the number of key/value heads. + + Additional keyword arguments are forwarded to the chosen layers constructor(s) and CacheProcessors. See the + documentation of the relevant `CacheLayerMixin` class and `CacheProcessor` class for more details. + """ + + def __init__( + self, + layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]], + config: Optional[PretrainedConfig] = None, + cache_processor: Optional[Union[str, type[CacheProcessor]]] = None, + max_batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + dtype: Optional[ms.Type] = None, + tp_size: Optional[int] = None, + **kwargs, + ): + self.layers: list[CacheLayerMixin] = [] + self.layer_classes = layer_classes + + processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor + kwargs.update( + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + dtype=dtype, + tp_size=tp_size, + ) + processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) + + self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs) + self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) + + self.append_new_layers(self.num_hidden_layers - 1) + self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None + + def __getitem__(self, layer_idx: int) -> tuple[ms.Tensor, ms.Tensor]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the sequence length. """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + if layer_idx < len(self.layers): + return self.layers[layer_idx].keys, self.layers[layer_idx].values else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) def __iter__(self): """ @@ -458,22 +901,58 @@ def __iter__(self): keys and values """ for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ - return len(self.key_cache) + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ + if getattr(self, "layers", None) is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 + # Empty dynamic caches initialize an empty layer to be ready for first update + dynamic_empty = ( + getattr(self, "layers", None) is not None + and len(self.layers) == 1 + and isinstance(self.layers[0], DynamicLayer) + and self.layers[0].keys is None + ) + return len(self.layers) if not dynamic_empty else 0 + + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" + + def append_new_layers(self, layer_idx: int) -> None: + """ + Appends layers to the cache until the layer `layer_idx` is reached. + Used for preallocation in static caches and on the fly in dynamic caches. + Args: + layer_idx (`int`): + The index of the layer to append. 
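+                Layers are created up to and including this index, using the layer classes and the
+                `layer_init_kwargs` captured when the cache was built.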
+ """ + while len(self.layers) <= layer_idx: + kwargs = self.layer_init_kwargs.copy() + if self.layer_init_kwargs.get("layer_device_map", None) is not None: + kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx] + + new_layer_class = ( + self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes + ) + new_layer = new_layer_class(**kwargs) + self.layers.append(new_layer) + + @apply_processors def update( self, key_states: ms.Tensor, value_states: ms.Tensor, layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ms.Tensor, ms.Tensor]: + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. @@ -484,58 +963,164 @@ def update( The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append([]) - self.value_cache.append([]) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - # content on layer cache can be a tensor and checking not tensor causes errors - # so we explicitly check for the empty list - elif len(self.key_cache[layer_idx]) == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = ops.cat([self.key_cache[layer_idx], key_states], axis=-2) - self.value_cache[layer_idx] = ops.cat([self.value_cache[layer_idx], value_states], axis=-2) + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []): + def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int: + """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): return 0 - return self.key_cache[layer_idx].shape[-2] + # Hack since QuantizedCache messes with keys shape as it becomes the residual cache + if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position) + return self.layers[layer_idx].get_seq_length(cache_position) + + def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. 
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) + return kv_length, kv_offset + + @property + def key_cache(self) -> KeyValuesWrapper: + """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`""" + logger.warning_once( + "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." + ) + return KeyValuesWrapper(self.layers, "keys") + + @property + def value_cache(self) -> KeyValuesWrapper: + """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`""" + logger.warning_once( + "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." + ) + return KeyValuesWrapper(self.layers, "values") + + ### Wrappers for layer operations and properties ### + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length.""" + return self.layers[layer_idx].get_max_cache_shape() + + def reset(self): + """Recursively reset all layers tensors""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reset() + + def reorder_cache(self, beam_idx: ms.Tensor): + """Reorder the cache for beam search""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reorder_cache(beam_idx) + + def crop(self, max_length: int): + """Crop the cache to the given length""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].crop(max_length) + + def batch_repeat_interleave(self, repeats: int): + """Repeat and interleave the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: ms.Tensor): + """Select indices from the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_select_indices(indices) + + @property + def max_batch_size(self) -> int: + """Return the maximum batch size of the cache""" + values = [layer.max_batch_size for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max batch size is not consistent across layers: {values}") + return values[0] + + @property + def max_cache_len(self) -> int: + """Return the maximum cache length of the cache""" + values = [layer.max_cache_len for layer in self.layers] + return max(values) + + @property + def is_compileable(self) -> bool: + """Return whether the cache is compileable""" + return all(layer.is_compileable for layer in self.layers) + + @property + def is_sliding(self) -> list[bool]: + """Return whether the layers of the cache are sliding window""" + return [getattr(layer, "is_sliding", False) for layer in self.layers] + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Example: - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. 
DynamicCache does not have a maximum length."""
-        return None
+    ```python
+    >>> from transformers import AutoTokenizer
+    >>> from mindone.transformers import AutoModelForCausalLM, DynamicCache
+
+    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+    >>> # Prepare a cache class and pass it to model's forward
+    >>> past_key_values = DynamicCache()
+    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    >>> outputs.past_key_values # access cache filled with key/values from generation
+    DynamicCache()
+    ```
+    """
 
-    def to_legacy_cache(self) -> Tuple[Tuple[ms.Tensor], Tuple[ms.Tensor]]:
-        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
-        backward compatibility."""
+    # Specialized constructor for DDP cache data, needed for BC
+    def __init__(self, ddp_cache_data: Optional[Iterable[tuple[ms.Tensor, ms.Tensor]]] = None, *args, **kwargs):
+        super().__init__(layer_classes=DynamicLayer, *args, **kwargs)
+        # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
+        # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
+        # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
+        # (shape=[global batch size, num_heads, seq_len, head_dim]).
+        # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break
+        # compatibility. The name of the argument doesn't matter.
+        if ddp_cache_data is not None:
+            for key_states, value_states in ddp_cache_data:
+                self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
+
+    def to_legacy_cache(self) -> tuple[tuple[ms.Tensor, ms.Tensor], ...]:
+        """
+        Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
+        backward compatibility.
+        """
         legacy_cache = ()
-        for layer_idx in range(len(self)):
-            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        for layer in self.layers:
+            legacy_cache += ((layer.keys, layer.values),)
         return legacy_cache
 
     @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
-        backward compatibility."""
+    def from_legacy_cache(cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], ...]) -> "Cache":
+        """
+        Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
+        backward compatibility.
+        """
         cache = cls()
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
@@ -543,67 +1128,35 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] =
                 cache.update(key_states, value_states, layer_idx)
         return cache
 
-    def crop(self, max_length: int):
-        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
-        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
-        # In case it is negative
-        if max_length < 0:
-            max_length = self.get_seq_length() - abs(max_length)
 
+class StaticCache(Cache):
+    """
+    Static Cache class to be used with `mindspore.jit(model)`.
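+
+    Key and value tensors are pre-allocated with shape `[max_batch_size, num_heads, max_cache_len, head_dim]`,
+    so cache shapes stay constant for every decoding step.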
- if self.get_seq_length() <= max_length: - return + See `Cache` for details on common methods that are implemented by all cache classes. - self._seen_tokens = max_length - for idx in range(len(self.key_cache)): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + Example: - def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import AutoModelForCausalLM, StaticCache - @classmethod - def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - layer_keys = ops.cat([current.key_cache[idx] for current in splits], dim=0) - layer_values = ops.cat([current.value_cache[idx] for current in splits], dim=0) - cache.update(layer_keys, layer_values, idx) - return cache + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = ops.repeat_interleave(self.key_cache[layer_idx], repeats, dim=0) - self.value_cache[layer_idx] = ops.repeat_interleave(self.value_cache[layer_idx], repeats, dim=0) + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") - def batch_select_indices(self, indices: ms.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ - def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, 0 + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=StaticLayer, *args, **kwargs) class SlidingWindowCache(StaticCache): @@ -631,13 +1184,14 @@ class SlidingWindowCache(StaticCache): smaller batch size is used. max_cache_len (`int`): The maximum sequence length with which the model will be used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + dtype (`ms.Type`, *optional*, defaults to `ms.float32`): The default `dtype` to use when initializing the layer. Example: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import AutoModelForCausalLM, SlidingWindowCache >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") @@ -705,45 +1259,128 @@ def update( # into consideration when building kv cache instead of just throwing away tokens outside of the window return key_states, value_states - slicing = ops.ones(self.max_cache_len, dtype=ms.int32).cumsum(0) - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position >= self.max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + slicing = ops.ones(self.max_cache_len, dtype=ms.int32).cumsum(0) + cache_position = cache_position.clamp(0, self.max_cache_len - 1) + to_shift = cache_position >= self.max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) + self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + + return k_out, v_out + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + def reset(self): + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) + self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) + + +class SlidingWindowCache(Cache): + """ + Sliding Window Cache class to be used with `mindspore.jit` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. 
Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + See `Cache` for details on common methods that are implemented by all cache classes. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs) - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `mindspore.jit` for models that alternate between a local sliding window + attention and global attention in every other layer (originally implemented for Gemma2). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of those layer types. - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) - self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) + See `Cache` for details on common methods that are implemented by all cache classes. 
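+
+    The attention type of each layer is read from `config.layer_types` when available; otherwise every layer
+    falls back to a static (global attention) layer.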
- self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out + Example: - return k_out, v_out + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import AutoModelForCausalLM, HybridCache - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) - self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + def __init__(self, config: PretrainedConfig, *args, **kwargs): + if hasattr(config, "layer_types"): + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + else: + # In this case, fall back to StaticCache + layer_classes = [StaticLayer] * config.num_hidden_layers + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) class EncoderDecoderCache(Cache): """ Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and cross-attention caches. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python - >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + >>> from mindone.transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") @@ -761,27 +1398,43 @@ class EncoderDecoderCache(Cache): """ + # Override @property from Cache + is_compileable = None + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__() + super().__init__(layer_classes=DynamicLayer) self.self_attention_cache = self_attention_cache self.cross_attention_cache = cross_attention_cache self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): + for layer_idx in range(len(cross_attention_cache)): self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) - def __getitem__(self, layer_idx: int) -> List[Tuple[ms.Tensor]]: + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield ( + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, + ) + + def __getitem__(self, layer_idx: int) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self): return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") @@ -793,8 +1446,8 @@ def __len__(self): """ return len(self.self_attention_cache) - def to_legacy_cache(self) -> Tuple[Tuple[ms.Tensor], Tuple[ms.Tensor]]: - """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + def to_legacy_cache(self) -> tuple[tuple[ms.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" legacy_cache = () if len(self.cross_attention_cache) > 0: for self_attn, cross_attn in zip( @@ -806,7 +1459,9 @@ def to_legacy_cache(self) -> Tuple[Tuple[ms.Tensor], Tuple[ms.Tensor]]: return legacy_cache @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None) -> "EncoderDecoderCache": + def from_legacy_cache( + cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], ...] + ) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" cache = cls( self_attention_cache=DynamicCache(), @@ -822,10 +1477,10 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = cache.is_updated[layer_idx] = True return cache - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - return self.self_attention_cache.get_seq_length(layer_idx) + # check if empty list because in case of static cache it will be a tensors and we can't check `if not ms.Tensor` + return self.self_attention_cache.get_seq_length(layer_idx, cache_position) def reset(self): if hasattr(self.self_attention_cache, "reset"): @@ -849,7 +1504,8 @@ def reorder_cache(self, beam_idx: ms.Tensor): def check_dynamic_cache(self, method: str): if not ( - isinstance(self.self_attention_cache, DynamicCache) and isinstance(self.cross_attention_cache, DynamicCache) + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) ): raise ValueError( f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " @@ -858,14 +1514,18 @@ def check_dynamic_cache(self, method: str): # TODO(gante, sanchit-gandhi): move following functionality into `.generate` def crop(self, maximum_length: int): - """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + """ + Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. + """ self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) - def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" + def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": + """ + Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils` + """ self.check_dynamic_cache(self.batch_split.__name__) self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) @@ -875,22 +1535,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDec out.append(EncoderDecoderCache(self_attn, cross_attn)) return out - @classmethod - def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. 
This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = ops.cat([current.self_attention_cache.key_cache[idx] for current in splits], axis=0) - layer_values = ops.cat([current.self_attention_cache.value_cache[idx] for current in splits], axis=0) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = ops.cat([current.cross_attention_cache.key_cache[idx] for current in splits], axis=0) - layer_values = ops.cat([current.cross_attention_cache.value_cache[idx] for current in splits], axis=0) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - def batch_repeat_interleave(self, repeats: int): """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" self.check_dynamic_cache(self.batch_repeat_interleave.__name__) @@ -903,204 +1547,213 @@ def batch_select_indices(self, indices: ms.Tensor): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + return self.self_attention_cache.get_max_cache_shape() -class HybridCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention - and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention - and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - batch_size (`int`): - The batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - dtype (torch.dtype, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. 
- - Example: + def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]: + return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache +class MambaCache: + def __init__(self): + raise NotImplementedError - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") +class OffloadedStaticCache(StaticCache): + def __init__(self): + raise NotImplementedError - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` +def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: """ + Parse processor arguments from kwargs based on the processor class init signature. - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. - # is_compileable = True - - # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. - def __init__( - self, - config: PretrainedConfig, - batch_size: int = None, - max_cache_len: int = None, - dtype: ms.Type = ms.float32, - max_batch_size: Optional[int] = None, - ) -> None: - super().__init__() - if batch_size is not None: - logger.warning_once( - f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'max_batch_size' argument instead." - ) - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - self.max_cache_len = max_cache_len - self.max_batch_size = batch_size or max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - - self.dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads - ) + Args: + processor_class: The processor class to inspect, or None + kwargs: Dictionary of keyword arguments - layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC - self.is_sliding = ms.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=ms.bool_ + Returns: + tuple: (processor_kwargs, remaining_kwargs) + """ + try: + params = list(inspect.signature(processor_class.__init__).parameters)[2:] + except Exception: + return {}, kwargs + + processor_kwargs = {k: kwargs[k] for k in params if k in kwargs} + remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs} + return processor_kwargs, remaining_kwargs + + +def parse_layer_args_from_model_config( + config: Optional[PretrainedConfig], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + dtype: Optional[ms.Type] = None, + tp_size: Optional[int] = None, + max_batch_size: Optional[int] = None, +) -> dict: + """ + Parse layer arguments from model configuration for cache initialization. + + Args: + config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info. + batch_size (`Optional[int]`): Batch size for cache initialization. + max_cache_len (`Optional[int]`): Maximum sequence length for cache. + dtype (`Optional[ms.Type]`): Data type for cache tensors. + tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads. + max_batch_size (`Optional[int]`): Maximum batch size for cache initialization. + + Returns: + `dict`: Dictionary containing parsed layer arguments for cache initialization. + """ + # No model config -> must be a dynamic cache, return bare dict + if config is None: + return {} + # Build the args dict for hybrid, sliding or static + else: + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if ( + getattr(config, "layer_types", None) is not None + and "sliding_attention" in config.layer_types + and "full_attention" in config.layer_types + ): + if getattr(config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." 
+ ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or config.max_position_embeddings + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + config.head_dim + if getattr(config, "head_dim", None) is not None + else config.hidden_size // config.num_attention_heads ) - self.key_cache: List[ms.Tensor] = [] - self.value_cache: List[ms.Tensor] = [] - global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) - sliding_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - min(config.sliding_window, max_cache_len), - self.head_dim, + num_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads ) - for i in range(config.num_hidden_layers): - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape - new_layer_key_cache = ops.zeros(cache_shape, dtype=self.dtype) - new_layer_value_cache = ops.zeros(cache_shape, dtype=self.dtype) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - if cache_position.shape[0] > max_cache_len: - k_out = key_states[:, :, -max_cache_len:, :] - v_out = value_states[:, :, -max_cache_len:, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = ops.ones(max_cache_len, dtype=ms.int32).cumsum(0) - cache_position = cache_position.clamp(0, max_cache_len - 1) - to_shift = cache_position >= max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % max_cache_len - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) - self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) + if tp_size is not None and tp_size > 1: + if num_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. 
+ num_heads //= tp_size + layer_args = { + "batch_size": max_batch_size if max_batch_size is not None else batch_size, + "max_cache_len": max_cache_len, + "dtype": dtype, + "head_dim": head_dim, + "num_heads": num_heads, + "sliding_window": getattr(config, "sliding_window", None), + } + return {k: v for k, v in layer_args.items() if v is not None} + +LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + "chunked_attention": ChunkedSlidingLayer, +} +PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { + "offloaded": OffloadedCacheProcessor, + "quanto_quantized": QuantizedCacheProcessor, + "hqq_quantized": HQQQuantizedCacheProcessor, +} - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - return k_out, v_out +class CacheConfig: + """ + Base class for cache configs + """ - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + cache_implementation: None - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config - def update( - self, - key_states: ms.Tensor, - value_states: ms.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ms.Tensor]: - cache_position = cache_kwargs.get("cache_position") - sliding_window = cache_kwargs.get("sliding_window") + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - if sliding_window: - update_fn = self._sliding_update - else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) + writer.write(json_string) - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. 
Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) - self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" - @property - def batch_size(self): - logger.warning_once( - f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." - ) - return self.max_batch_size + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. -class MambaCache: - def __init__(self): - raise NotImplementedError + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. 
+ """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) -class OffloadedStaticCache(StaticCache): - def __init__(self): - raise NotImplementedError + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs \ No newline at end of file From 2f70121b04512bc4f20adf8f6958eb9b08b7dac8 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:40:09 +0800 Subject: [PATCH 05/94] feat(transformers): upgrade modeling_utils to v4.54 --- mindone/transformers/modeling_utils.py | 879 ++++++++++++++++++++----- mindone/transformers/utils/generic.py | 16 + 2 files changed, 717 insertions(+), 178 deletions(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index c62042197c..01de7a138a 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -21,10 +21,11 @@ import json import os import re +import sys import warnings from contextlib import contextmanager, nullcontext from dataclasses import dataclass -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union, get_type_hints from transformers.configuration_utils import PretrainedConfig from transformers.dynamic_module_utils import custom_object_save @@ -67,6 +68,7 @@ from .integrations.flash_attention import flash_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward from .loss.loss_utils import LOSS_MAPPING +from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .mindspore_adapter import dtype_to_str from .mindspore_utils import ( # noqa: F401 Conv1D, @@ -78,6 +80,7 @@ ) from .modeling_attn_mask_utils import dtype_to_min from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available +from .utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder if is_safetensors_available(): from safetensors import safe_open @@ -85,6 +88,27 @@ # from mindone.safetensors.mindspore import load_file as safe_load_file from mindone.safetensors.mindspore import save_file as safe_save_file +# DO NOT MODIFY, KEPT FOR BC ONLY +VLMS = [ + "aria", + "ayavision", + "colpali", + "emu3", + "fuyu", + "gotocr2", + "gemma3", + "internvl", + "llava", # all llava prefixed models fall under this check + "mistral3", + "mllama", + "paligemma", + "shieldgemma2", + "qwen2vl", + "qwen2_5_vl", + "videollava", + "vipllava", +] + logger = logging.get_logger(__name__) _init_weights = True @@ -372,6 +396,97 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name +def _get_mindspore_dtype( + cls, + mindspore_dtype: Optional[Union[str, ms.Type, dict]], + checkpoint_files: Optional[list[str]], + config: PretrainedConfig, + sharded_metadata: Optional[dict], + state_dict: Optional[dict], + weights_only: bool, + is_sharded: bool, +): + # set dtype to instantiate the model under: + # 1. If mindspore_dtype is not None, we use that dtype + # 2. 
If mindspore_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + + if mindspore_dtype is not None: + config.mindspore_dtype = dtype_to_str(mindspore_dtype) + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.mindspore_dtype = mindspore_dtype + if isinstance(mindspore_dtype, str): + if mindspore_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + mindspore_dtype = config.torch_dtype + logger.info(f"Will use dtype={mindspore_dtype} as defined in model's config object") + else: + if is_sharded and "dtype" in sharded_metadata: + mindspore_dtype = sharded_metadata["dtype"] + elif not is_sharded: + mindspore_dtype = get_state_dict_dtype(state_dict) + else: + one_state_dict = load_state_dict(checkpoint_files[0]) + mindspore_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + logger.info( + f"Since the `torch_dtype` attribute can't be found in model's config object, " + f"will use dtype={mindspore_dtype} as derived from model's weights" + ) + else: + raise ValueError( + f'`mindspore_dtype` can be either `ms.Type` or `"auto"`, but received {mindspore_dtype}' + ) + # TODO: We cannot set default mindspore dtype! + return config, mindspore_dtype + +def _find_missing_and_unexpected_keys( + cls, + model: "PreTrainedModel", + original_checkpoint_keys: List[str], + checkpoint_keys: List[str], + loading_base_model_from_task_state_dict: bool, +) -> Tuple[List[str], List[str]]: + """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys + (keys found in the loaded state dict keys, but that are NOT part of the model parameters) + """ + prefix = model.base_model_prefix + + # Compute expected keys, i.e. keys that the FULL model (not model_to_load) expects + expected_keys = list(model.state_dict().keys()) + + # Adjust prefix of the keys to make them match loaded keys before removing them + missing_keys = sorted(set(expected_keys) - set(checkpoint_keys)) + unexpected_keys = set(checkpoint_keys) - set(expected_keys) + # If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys + if loading_base_model_from_task_state_dict: + task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")] + unexpected_keys.update(task_specific_keys) + + # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but + # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway + model_buffers = {n for n, _ in model.named_buffers()} + unexpected_keys = sorted(unexpected_keys - model_buffers) + + # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model + # (so the buffer name has changed). Remove them in such a case + has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers) + if has_inv_freq_buffers: + unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k] + + # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...) 
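The set arithmetic above, together with the pattern-based filtering that follows, is easiest to see on a toy input; the key names and the ignore pattern below are made up for illustration:

```python
import re

expected_keys = ["model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight", "lm_head.weight"]
checkpoint_keys = ["model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight", "model.rotary_emb.inv_freq"]

# missing: expected by the model but absent from the checkpoint
missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
# unexpected: present in the checkpoint but not expected by the model
unexpected_keys = sorted(set(checkpoint_keys) - set(expected_keys))

# same mechanism as the `_keys_to_ignore_on_load_unexpected` filtering below
for pattern in [r"rotary_emb\.inv_freq"]:
    unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]

print(missing_keys)     # ['lm_head.weight']
print(unexpected_keys)  # []
```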
+ if cls._keys_to_ignore_on_load_missing is not None: + for pattern in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pattern, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pattern in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None] + + return missing_keys, unexpected_keys + class ModuleUtilsMixin: """ @@ -382,10 +497,6 @@ def _get_name(self): return self.__class__.__name__ def to(self, dtype: Optional[ms.Type] = None): - # FIXME: In ms 2.6.0 `tensor.set_dtype()` encountered a bug that it occurs wrong values. - # Resume to use self.register_buffer() in network and set dtype for buffer tensors after ms2.7.0 launched. - # Now we use `Parameter` and `Parameter.set_dtype()` instead. - for p in self.get_parameters(): p.set_dtype(dtype) return self @@ -617,8 +728,96 @@ def floating_point_ops(self, input_dict: Dict[str, Union[ms.Tensor, Any]], exclu return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) +class EmbeddingAccessMixin: + """ + Base utilities to regroup getters and setters for embeddings. + Introduces the `input_layer_embed` attribute, which indicates + where the input embeddings come from and where they + should be set. + """ + + _input_embed_layer = "embed_tokens" # default layer that holds input embeddings. -class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): + def get_input_embeddings(self) -> nn.Cell: + """ + Returns the model's input embeddings. + + Returns: + `nn.Cell`: A mindspore module mapping vocabulary to hidden states. + """ + + # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer + # for most NLP models), and if so, return it. + + name = getattr(self, "_input_embed_layer", "embed_tokens") + + if (default_embedding := getattr(self, name, None)) is not None: + return default_embedding + # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration` + + if hasattr(self, "model") and hasattr(self.model, "embed_tokens"): + return self.model.embed_tokens + + # 3) vanilla decoder‑only architectures + elif hasattr(self, "embed_tokens"): + return self.embed_tokens + else: + base_model = getattr(self, "base_model_prefix", None) + if base_model is not None: + base_model = getattr(self, base_model, None) + if base_model is not None and base_model is not self: + return base_model.get_input_embeddings() + raise NotImplementedError( + f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; " + "please override in the subclass." + ) + + def set_input_embeddings(self, value: nn.Cell): + """Fallback setter that handles **~70 %** of models in the code‑base. + + Order of attempts: + 1. `self.model.embed_tokens` + 2. `self.embed_tokens` + 3. delegate to the *base model* if one exists + 4. otherwise raise `NotImplementedError` so subclasses still can (and + should) override for exotic layouts. + """ + + # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration` + name = getattr(self, "_input_embed_layer", "embed_tokens") + if hasattr(self, "model") and hasattr(self.model, name): + setattr(self.model, name, value) + # 2) as well as vanilla decoder‑only architectures + elif hasattr(self, name): + setattr(self, name, value) + # 3) recurse once into the registered *base* model (e.g. 
for encoder/decoder) + elif getattr(self, self.base_model_prefix, self) is not self: + base_model = getattr(self, self.base_model_prefix, self) + base_model.set_input_embeddings(value) + else: + raise NotImplementedError( + f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass." + ) + + def get_output_embeddings(self): + if not hasattr(self, "lm_head"): + return None + try: + # Speech / vision backbones raise here, so we return None. + # Legit use of get_input_embs? + self.get_input_embeddings() + except NotImplementedError: + return None + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """ + Sets the model's output embedding, defaulting to setting new_embeddings to lm_head. + """ + if getattr(self, "lm_head"): + self.lm_head = new_embeddings + +class PreTrainedModel(nn.Cell, EmbeddingAccessMixin, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. @@ -651,10 +850,16 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin main_input_name = "input_ids" model_tags = None + _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models + _auto_class = None _no_split_modules = None _skip_keys_device_placement = None + _keep_in_fp32_modules = None + # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16 + # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag + _keep_in_fp32_modules_strict = None # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. @@ -672,8 +877,8 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin is_parallelizable = False supports_gradient_checkpointing = False - # Flash Attention 2 support - _supports_flash_attn_2 = False + # Flash Attention support + _supports_flash_attn = False # SDPA support _supports_sdpa = False @@ -692,6 +897,49 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin # In practice, it means that they support attention interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan _supports_attention_backend = False + _can_record_outputs = None + + @property + def can_record_outputs(self) -> dict[str, OutputRecorder]: + """ + Maps output names (e.g., "attentions", "hidden_states") + to either: + - A module class (e.g., `LlamaDecoderLayer`), using default index conventions: + * index=0 for "hidden_states" + * index=1 for "attentions" + - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`. + + Examples: + These two are equivalent: + + ```python + _can_record_outputs = { + "attentions": LlamaAttention, + "hidden_states": LlamaDecoderLayer + } + + _can_record_outputs = { + "attentions": OutputRecorder(LlamaAttention, index=1), + "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0) + } + ``` + + This means you can record outputs from the same class, by specifying a layer name. Before + collecting outputs, we check that they come from this layer. 
+ + If you have cross attention that come from `LlamaAttention` and self attention that also + come from `LlamaAttention` but from `self_attn` you can do this: + + ```python + class LlamaModel(PreTrainedModel): + _can_record_outputs = { + "attentions": OutputRecorder(LlamaAttention, index=1, layer-name="self_attn"), + "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn") + } + + ``` + """ + return self._can_record_outputs or {} @property def dummy_inputs(self) -> Dict[str, Tensor]: @@ -707,6 +955,30 @@ def framework(self) -> str: """ return "ms" + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # For BC we keep the original `config_class` definition in case + # there is a `config_class` attribute (e.g. remote code models), + # otherwise we derive it from the annotated `config` attribute. + + # defined in this particular subclass + child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None) + child_attribute = cls.__dict__.get("config_class", None) + + # defined in the class (this subclass or any parent class) + full_annotation = get_type_hints(cls).get("config", None) + full_attribute = cls.config_class + + # priority (child class_config -> child annotation -> global class_config -> global annotation) + if child_attribute is not None: + cls.config_class = child_attribute + elif child_annotation is not None: + cls.config_class = child_annotation + elif full_attribute is not None: + cls.config_class = full_attribute + elif full_annotation is not None: + cls.config_class = full_annotation + def __init__(self, config: PretrainedConfig, *inputs, **kwargs): super().__init__() if not isinstance(config, PretrainedConfig): @@ -717,6 +989,13 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): ) # Save config and origin of the pretrained weights if given in model self.config = config + + # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid + # setting it recursively) + self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation( + self.config._attn_implementation, is_init_check=True + ) + self.name_or_path = config.name_or_path self.warnings_issued = {} self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None @@ -724,6 +1003,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute # when a different component (e.g. language_model) is used. self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) + + self._no_split_modules = self._no_split_modules or [] + _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only def post_init(self): """ @@ -732,69 +1015,6 @@ def post_init(self): """ self.init_weights() - @classmethod - def _autoset_attn_implementation( - cls, - config, - use_flash_attention_2: bool = False, - mindspore_dtype=None, - ): - """ - Automatically checks and dispatches to a default attention implementation. In order of priority: - 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). - 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. 
(`LlamaFlashAttention` for example) - 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) - 4. The default model's implementation otherwise (`LlamaAttention` for example) . - """ - # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. - # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). - # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) - requested_attn_implementation = None - if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: - if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: - raise ValueError( - f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were ' - f"used when loading the model, which are not compatible." - ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' - ) - - if config._attn_implementation not in ["eager", "paged_attention"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): - message = ( - f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. ' - f'The only possible arguments are `attn_implementation="eager"`' - f" (manual attention implementation)" - ) - if cls._supports_flash_attn_2: - message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' - if cls._supports_sdpa: - message += ', `"attn_implementation=sdpa"` (implementation using scaled_dot_product_attention)' - raise ValueError(message + ".") - - # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the - # user-provided config, with hard checks that the requested attention implementation is available. - requested_attn_implementation = config._attn_implementation_internal - - if use_flash_attention_2: - logger.warning_once( - "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a " - 'future release. Please use `attn_implementation="flash_attention_2"` instead.' - ) - config._attn_implementation = "flash_attention_2" - if config._attn_implementation == "flash_attention_2": - cls._check_and_enable_flash_attn_2( - config, - mindspore_dtype=mindspore_dtype, - hard_check_only=False, - ) - elif requested_attn_implementation in [None, "sdpa"]: - # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. - config = cls._check_and_enable_sdpa( - config, - hard_check_only=False if requested_attn_implementation is None else True, - ) - - return config - @property def base_model(self) -> nn.Cell: """ @@ -822,12 +1042,13 @@ def can_generate(cls) -> bool: continue if "PreTrainedModel" not in str(base) and base.can_generate(): return True - # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this + + # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this # was how we detected whether a model could generate. 
- if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): - logger.warning_once( + if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin` + logger.warning( f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " - "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " + "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " "to call `generate` and other related functions." "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " @@ -837,7 +1058,6 @@ def can_generate(cls) -> bool: "\n - If you are not the owner of the model architecture class, please contact the model code owner " "to update it." ) - return True # Otherwise, can't generate return False @@ -963,20 +1183,9 @@ def _from_config(cls, config, **kwargs): config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. - if config._attn_implementation_internal is not None: - # In this case, the config has been created with the attn_implementation set by the user, which we - # should respect. - attn_implementation = config._attn_implementation_internal - else: - attn_implementation = None - - config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) - if not getattr(config, "_attn_implementation_autoset", False): - config = cls._autoset_attn_implementation( - config, - use_flash_attention_2=use_flash_attention_2, - mindspore_dtype=mindspore_dtype, - ) + # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs) + if "attn_implementation" in kwargs: + config._attn_implementation = kwargs.pop("attn_implementation") model = cls(config, **kwargs) @@ -1025,6 +1234,255 @@ def get_output_embeddings(self) -> nn.Cell: """ return None # Overwrite for models with output embeddings + def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: + """ + Check the availability of Flash Attention 2 for a given model. + + Args: + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. + """ + mindspore_dtype = self.config.torch_dtype + if isinstance(mindspore_dtype, str): + mindspore_dtype = getattr(ms, mindspore_dtype) + elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): + TORCH_TO_MINDSPORE_DTYPE_MAP = { + "torch.float32": ms.float32, + "torch.bfloat16": ms.bfloat16, + "torch.float16": ms.float16, + } + mindspore_dtype = str(mindspore_dtype) + mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype] + + # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases + if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)): + raise ValueError( + f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. 
Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_2_available(): + preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + + if mindspore_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif mindspore_dtype is not None and mindspore_dtype not in [ms.float16, ms.bfloat16]: + logger.warning_once( + "Flash Attention 2 only supports ms.float16 and ms.bfloat16 dtypes, but" + f" the current dype in {self.__class__.__name__} is {mindspore_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`' + ) + + # With the early check, the parameters are not yet initalized correctly + if not is_init_check: + if getattr(self, "use_bettertransformer", False): + raise ValueError( + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + # If no error raise by this point, we can return `True` + return True + + def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool: + """ + Check the availability of SDPA for a given model. + + Args: + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. + """ + if not self._supports_sdpa: + raise ValueError( + f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" + ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_sdpa_available(): + raise ImportError("MindSpore SDPA requirements in Transformers are not met.") + + return True + + def _check_and_adjust_attn_implementation( + self, attn_implementation: Optional[str], is_init_check: bool = False + ) -> str: + """ + Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if + it matches hf kernels pattern. + + Args: + attn_implementation (`str` or `None`): + The attention implementation to check for existence/validity. + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. 
at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. + + Returns: + `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from + None to sdpa (to potentially eager). + """ + applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation + if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation): + + # Extract repo_id and kernel_name from the string + if ":" in applicable_attn_implementation: + repo_id, kernel_name = attn_implementation.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = attn_implementation + kernel_name = None + repo_id = repo_id.strip() + try: + # fixme there is no implementation for kernel in mindspore + ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) + applicable_attn_implementation = repo_id + except Exception as e: + logger.warning_once( + f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " + "default attention implementation instead (sdpa if available, eager otherwise)." + ) + applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case + if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): + message = ( + f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are ' + '`attn_implementation="eager"` (manual attention implementation)' + ) + # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases + if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False): + message += ( + ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' + ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + ) + if self._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if self._supports_flex_attn: + message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + raise ValueError(message + ".") + + # Perform relevant checks + if applicable_attn_implementation == "flash_attention_2": + self._flash_attn_2_can_dispatch(is_init_check) + elif applicable_attn_implementation == "flash_attention_3": + self._flash_attn_3_can_dispatch(is_init_check) + elif applicable_attn_implementation == "flex_attention": + self._flex_attn_can_dispatch(is_init_check) + elif applicable_attn_implementation == "sdpa": + # Sdpa is the default, so we try it and fallback to eager otherwise when not possible + try: + self._sdpa_can_dispatch(is_init_check) + except (ValueError, ImportError) as e: + # In this case, sdpa was requested explicitly, but we can't use it, so let's raise + if attn_implementation == "sdpa": + raise e + applicable_attn_implementation = "eager" + + return applicable_attn_implementation + + @classmethod + def _can_set_attn_implementation(cls) -> bool: + """Detect whether the class supports setting its attention implementation dynamically. 
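For reference, the hub-kernel branch of `_check_and_adjust_attn_implementation` above only triggers on strings shaped like `repo/name[:kernel]`; a standalone sketch of that parsing (the repo id below is invented):

```python
import re

for attn_implementation in ("my-org/my-kernels:flash_attn_like", "my-org/my-kernels", "sdpa"):
    if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", attn_implementation):
        if ":" in attn_implementation:
            repo_id, kernel_name = (s.strip() for s in attn_implementation.split(":"))
        else:
            repo_id, kernel_name = attn_implementation.strip(), None
        print("kernel repo:", repo_id, kernel_name)
    else:
        print("named implementation:", attn_implementation)
```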
It is an ugly check based on + opening the file, but avoids maintaining yet another property flag. + """ + class_file = sys.modules[cls.__module__].__file__ + with open(class_file, "r") as f: + code = f.read() + # heuristic -> if we find those patterns, the model uses the correct interface + return ( + "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code + ) + + def set_attn_implementation(self, attn_implementation: Union[str, dict]): + """ + Set the requested `attn_implementation` for this model. + + Args: + attn_implementation (`str` or `dict`): + The attention implementation to set for this model. It can be either a `str`, in which case it will be + dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each + submodel will dispatch the corresponding value. + """ + requested_implementation = ( + attn_implementation + if not isinstance(attn_implementation, dict) + else attn_implementation.get("", self.config._attn_implementation) + ) + + # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply + # warn the user that the requested value is not working + if requested_implementation != self.config._attn_implementation: + # In this case, raise + if not self._can_set_attn_implementation(): + logger.warning( + f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it " + "does not follow the functional approach based on AttentionInterface " + "(see https://huggingface.co/docs/transformers/en/attention_interface)" + ) + else: + try: + applicable_attn_implementation = self._check_and_adjust_attn_implementation( + requested_implementation, is_init_check=False + ) + # Apply the change (on the internal attr, to avoid setting it recursively) + self.config._attn_implementation_internal = applicable_attn_implementation + except (ValueError, ImportError) as e: + logger.warning( + f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}" + ) + + subconfigs_changed = set() + # Apply it to all submodels as well + for submodule in self.modules(): + # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model", + # e.g. 
ForCausalLM has a Model inside, but no need to check it again) + if ( + submodule is not self + and isinstance(submodule, PreTrainedModel) + and submodule.config.__class__ != self.config.__class__ + ): + sub_implementation = attn_implementation + if isinstance(attn_implementation, dict): + for subconfig_key in self.config.sub_configs: + # We need to check for exact object match here, with `is` + if getattr(self.config, subconfig_key) is submodule.config: + sub_implementation = attn_implementation.get( + subconfig_key, submodule.config._attn_implementation + ) + break + submodule.set_attn_implementation(sub_implementation) + subconfigs_changed.add(submodule.config.__class__) + + # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel + for subconfig_key in self.config.sub_configs: + subconfig = getattr(self.config, subconfig_key) + requested_implementation = ( + attn_implementation + if not isinstance(attn_implementation, dict) + else attn_implementation.get(subconfig_key, subconfig._attn_implementation) + ) + # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered + if ( + subconfig.__class__ not in subconfigs_changed + and requested_implementation != subconfig._attn_implementation + and requested_implementation in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() + ): + subconfig._attn_implementation_internal = requested_implementation + logger.warning( + f"We set the attention implementation for the sub-config `{subconfig_key}` to `{requested_implementation}` " + "without finding the associated sub-model. For this reason we could not check if the model supports it. " + "You may encounter undefined behavior." + ) + def _init_weights(self, module): """ Initialize the weights. This method should be overridden by derived class and is @@ -1717,6 +2175,7 @@ def from_pretrained( token: Optional[Union[str, bool]] = None, revision: str = "main", use_safetensors: bool = None, + weights_only: bool = True, **kwargs, ): r""" @@ -2298,40 +2757,10 @@ def from_pretrained( # Time to load the checkpoint state_dict = load_state_dict(resolved_archive_file) - # set dtype to instantiate the model under: - # 1. If mindspore_dtype is not None, we use that dtype - # 2. 
If mindspore_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first - # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype - # we also may have config.torch_dtype available, but we won't rely on it till v5 - - if mindspore_dtype is not None: - config.mindspore_dtype = dtype_to_str(mindspore_dtype) - for sub_config_key in config.sub_configs.keys(): - sub_config = getattr(config, sub_config_key) - sub_config.mindspore_dtype = mindspore_dtype - if isinstance(mindspore_dtype, str): - if mindspore_dtype == "auto": - if hasattr(config, "torch_dtype") and config.torch_dtype is not None: - mindspore_dtype = config.torch_dtype - logger.info(f"Will use dtype={mindspore_dtype} as defined in model's config object") - else: - if is_sharded and "dtype" in sharded_metadata: - mindspore_dtype = sharded_metadata["dtype"] - elif not is_sharded: - mindspore_dtype = get_state_dict_dtype(state_dict) - else: - one_state_dict = load_state_dict(resolved_archive_file[0]) - mindspore_dtype = get_state_dict_dtype(one_state_dict) - del one_state_dict # free CPU memory - logger.info( - f"Since the `torch_dtype` attribute can't be found in model's config object, " - f"will use dtype={mindspore_dtype} as derived from model's weights" - ) - else: - raise ValueError( - f'`mindspore_dtype` can be either `ms.Type` or `"auto"`, but received {mindspore_dtype}' - ) - # TODO: We cannot set default mindspore dtype! + # Find the correct dtype based on current state + config, mindspore_dtype = _get_mindspore_dtype( + cls, mindspore_dtype, resolved_archive_file, config, sharded_metadata, state_dict, weights_only, is_sharded + ) # Check if `_keep_in_fp32_modules` is not None use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (mindspore_dtype == ms.float16) @@ -2344,9 +2773,6 @@ def from_pretrained( config.name_or_path = pretrained_model_name_or_path config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. - config = cls._autoset_attn_implementation( - config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype - ) model = cls(config, *model_args, **model_kwargs) @@ -2442,6 +2868,83 @@ def from_pretrained( return model + @staticmethod + def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]: + """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" + # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert) + # This rename is logged. + if key.endswith("LayerNorm.beta"): + return key.replace("LayerNorm.beta", "LayerNorm.bias"), True + if key.endswith("LayerNorm.gamma"): + return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True + + return key, False + + def _get_key_renaming_mapping( + self, + checkpoint_keys: List[str], + key_mapping: Optional[Dict[str, str]] = None, + loading_base_model_from_task_state_dict: bool = False, + loading_task_model_from_base_state_dict: bool = False, + ): + """ + Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model + that we are loading expects. This is the single entry point for key renaming that will be used during + loading. + Log if any parameters have been renamed. + """ + prefix = self.base_model_prefix + _prefix = f"{prefix}." 
+ + renamed_keys = {} + key_renaming_mapping = {} + for key in checkpoint_keys: + # Class specific rename + new_key, has_changed = self._fix_state_dict_key_on_load(key) + + # Optionally map the key according to `key_mapping` + if key_mapping is not None: + for pattern, replacement in key_mapping.items(): + new_key, n_replace = re.subn(pattern, replacement, new_key) + # Early exit of the loop + if n_replace > 0: + has_changed = True + break + + # In this case, we need to add the prefix to the keys, to match them to the expected keys + if loading_task_model_from_base_state_dict: + new_key = ".".join([prefix, new_key]) + key = ".".join([prefix, key]) + # In this case we need to remove the prefix from the key to match them to the expected keys, and use + # only the keys starting with the prefix + elif loading_base_model_from_task_state_dict: + if not new_key.startswith(_prefix): + continue + new_key = new_key[len(_prefix) :] + key = key[len(_prefix) :] + + if not has_changed: + key_renaming_mapping[new_key] = new_key + else: + key_renaming_mapping[key] = new_key + + # track gamma/beta rename for logging + if has_changed: + if key.endswith("LayerNorm.gamma"): + renamed_keys["LayerNorm.gamma"] = (key, new_key) + elif key.endswith("LayerNorm.beta"): + renamed_keys["LayerNorm.beta"] = (key, new_key) + + if renamed_keys: + warning_msg = f"A pretrained model of type `{self.__class__.__name__}` " + warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" + for old_key, new_key in renamed_keys.values(): + warning_msg += f"* `{old_key}` -> `{new_key}`\n" + warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." + logger.info_once(warning_msg) + + return key_renaming_mapping + @classmethod def _load_pretrained_model( cls, @@ -2454,50 +2957,46 @@ def _load_pretrained_model( sharded_metadata=None, dtype=None, keep_in_fp32_modules=None, + key_mapping: Optional[Dict[str, str]] = None, + weights_only: bool = True, ): - model.tie_weights() - - # Retrieve missing & unexpected_keys model_state_dict = {k: v for k, v in model.parameters_and_names()} - expected_keys = list(model_state_dict.keys()) prefix = model.base_model_prefix original_loaded_keys = loaded_keys - if len(prefix) > 0: - has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) - expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + # Get all the keys of the state dicts that we have to initialize the model + if sharded_metadata is not None: + original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"] + elif state_dict is not None: + original_checkpoint_keys = list(state_dict.keys()) else: - has_prefix_module = False - expects_prefix_module = False - - # Mapping loaded_keys from pt to ms - pt2ms_mappings = _get_pt2ms_mappings(model) - loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix) - - # key re-naming operations are never done on the keys - # that are loaded, but always on the keys of the newly initialized model - remove_prefix_from_model = not has_prefix_module and expects_prefix_module - add_prefix_to_model = has_prefix_module and not expects_prefix_module - - if remove_prefix_from_model: - _prefix = f"{prefix}." 
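A toy walk-through of the renaming pass above: the legacy `gamma` suffix is rewritten first, then an optional `key_mapping` of regex patterns is applied via `re.subn`; the pattern and key names below are invented:

```python
import re

key_mapping = {r"^backbone\.": "model."}  # hypothetical pattern -> replacement
checkpoint_keys = ["backbone.encoder.LayerNorm.gamma", "backbone.encoder.dense.weight"]

key_renaming_mapping = {}
for key in checkpoint_keys:
    new_key, has_changed = key, False
    if new_key.endswith("LayerNorm.gamma"):
        new_key, has_changed = new_key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
    for pattern, replacement in key_mapping.items():
        new_key, n_replace = re.subn(pattern, replacement, new_key)
        if n_replace > 0:
            has_changed = True
            break
    key_renaming_mapping[key] = new_key

print(key_renaming_mapping)
# {'backbone.encoder.LayerNorm.gamma': 'model.encoder.LayerNorm.weight',
#  'backbone.encoder.dense.weight': 'model.encoder.dense.weight'}
```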
- expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] - expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] - elif add_prefix_to_model: - expected_keys = [".".join([prefix, s]) for s in expected_keys] - - missing_keys = sorted(set(expected_keys) - set(loaded_keys)) - unexpected_keys = set(loaded_keys) - set(expected_keys) - - # Some models may have keys that are not in the state by design, removing them before needlessly warning - # the user. - if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + original_checkpoint_keys = list(load_state_dict(pretrained_model_name_or_path).keys()) + + # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture + prefix = model.base_model_prefix + _prefix = f"{prefix}." + has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False + expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False + loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module + loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module + + # Find the key names that the model expects from the serialized keys + key_renaming_mapping = model._get_key_renaming_mapping( + original_checkpoint_keys, + key_mapping, + loading_base_model_from_task_state_dict, + loading_task_model_from_base_state_dict, + ) + checkpoint_keys = list(key_renaming_mapping.values()) + + # Find missing and unexpected keys from the state dict + missing_keys, unexpected_keys = _find_missing_and_unexpected_keys( + cls, + model, + original_checkpoint_keys, + checkpoint_keys, + loading_base_model_from_task_state_dict, + ) # Set some modules to fp32 if any if keep_in_fp32_modules is not None: @@ -2505,20 +3004,32 @@ def _load_pretrained_model( if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): param.set_dtype(ms.float32) - # Make sure we are able to load base models as well as derived models (with heads) - start_prefix = "" + # Make sure we are able to load base models as well as derived models (specific task models, with heads) model_to_load = model - if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: - start_prefix = cls.base_model_prefix + "." 
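The two `loading_*` flags above boil down to a cross-check between the checkpoint key prefixes and the model's own attributes; a toy version with illustrative names:

```python
# A task model wraps the base model under `prefix`, while a base checkpoint
# stores its keys without that prefix.
prefix = "bert"
checkpoint_keys = ["embeddings.word_embeddings.weight", "encoder.layer.0.output.dense.weight"]
model_attributes = {"bert", "classifier"}  # stand-in for hasattr(model, prefix)

has_prefix_module = any(k.startswith(prefix) for k in checkpoint_keys)
expects_prefix_module = prefix in model_attributes

loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module

print(loading_task_model_from_base_state_dict)   # True: keys must gain the "bert." prefix
print(loading_base_model_from_task_state_dict)   # False
```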
- if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: - model_to_load = getattr(model, cls.base_model_prefix) - base_model_expected_keys = list(k for k, v in model_to_load.parameters_and_names()) - if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): + # In this case, we load a ForTaskModel with keys from a BaseModel -> only load keys to the BaseModel + if loading_task_model_from_base_state_dict: + model_to_load = getattr(model, prefix) + # Here we need to remove the prefix we added to correctly find missing/unexpected keys, as we will load + # in the submodule + key_renaming_mapping = {k: v[len(_prefix) :] for k, v in key_renaming_mapping.items()} + checkpoint_keys = list(key_renaming_mapping.values()) + # small sanity check: the base model should not contain task-specific head keys + task_specific_expected_keys = [s for s in model.state_dict().keys() if not s.startswith(_prefix)] + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any( + key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys + ): raise ValueError( "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " "properly saved?" ) + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + def _find_mismatched_keys( state_dict, model_state_dict, @@ -2563,12 +3074,17 @@ def _find_mismatched_keys( # Whole checkpoint state_dict = _convert_state_dict(model, state_dict, prefix) + matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s] + if matching: + # Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta + state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} + mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, original_loaded_keys, - add_prefix_to_model, - remove_prefix_from_model, + loading_task_model_from_base_state_dict, + loading_base_model_from_task_state_dict, ignore_mismatched_sizes, ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=False) @@ -2590,14 +3106,21 @@ def _find_mismatched_keys( state_dict = load_state_dict(shard_file) state_dict = _convert_state_dict(model, state_dict, prefix) + matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s] + if matching: + # Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta + state_dict = { + key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping + } + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. 
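A simplified stand-in for the kind of shape check `_find_mismatched_keys` performs before the accumulation that follows; the shapes and key names are invented, and the real helper also drops the offending entries from the state dict:

```python
checkpoint_shapes = {"lm_head.weight": (32000, 4096)}
model_shapes = {"lm_head.weight": (32128, 4096)}
ignore_mismatched_sizes = True

mismatched_keys = []
for key, ckpt_shape in checkpoint_shapes.items():
    model_shape = model_shapes.get(key)
    if model_shape is not None and ckpt_shape != model_shape:
        if ignore_mismatched_sizes:
            # recorded and reported back to the caller; the weight itself is skipped
            mismatched_keys.append((key, ckpt_shape, model_shape))
        else:
            # simplified: the real loading path surfaces this as a loading error instead
            raise RuntimeError(f"size mismatch for {key}: {ckpt_shape} vs {model_shape}")

print(mismatched_keys)  # [('lm_head.weight', (32000, 4096), (32128, 4096))]
```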
mismatched_keys += _find_mismatched_keys( state_dict, model_state_dict, original_loaded_keys, - add_prefix_to_model, - remove_prefix_from_model, + loading_task_model_from_base_state_dict, + loading_base_model_from_task_state_dict, ignore_mismatched_sizes, ) @@ -3207,4 +3730,4 @@ def valid_keys(self) -> List[str]: ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() # for BC -MSPreTrainedModel = PreTrainedModel +MSPreTrainedModel = PreTrainedModel \ No newline at end of file diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index d571e4a53e..670454e373 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -34,6 +34,7 @@ from .import_utils import is_mindspore_available +_CAN_RECORD_REGISTRY = {} class cached_property(property): """ @@ -467,6 +468,21 @@ class TransformersKwargs(TypedDict, total=False): output_router_logits: Optional[bool] +class OutputRecorder: + """ + Configuration for recording outputs from a model via hooks. + + Attributes: + target_class (Type): The class (e.g., nn.Cell) to which the hook will be attached. + index (Optional[int]): If the output is a tuple/list, optionally record only at a specific index. + layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn". + """ + + target_class: "type[ms.nn.Cell]" + index: Optional[int] = 0 + layer_name: Optional[str] = None + + def filter_out_non_signature_kwargs(extra: Optional[list] = None): """ Decorator to filter out named arguments that are not in the function signature. From 44ad42494b74ac67eb7bbecaa647adc6ae0dd293 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:04:00 +0800 Subject: [PATCH 06/94] feat(transformers): upgrade generation/utils to v4.54 --- mindone/transformers/generation/utils.py | 576 ++++++++++++++++++----- mindone/transformers/modeling_utils.py | 66 ++- 2 files changed, 488 insertions(+), 154 deletions(-) diff --git a/mindone/transformers/generation/utils.py b/mindone/transformers/generation/utils.py index 5a7b345e18..d17a899226 100644 --- a/mindone/transformers/generation/utils.py +++ b/mindone/transformers/generation/utils.py @@ -16,6 +16,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import copy import inspect import time @@ -24,11 +25,18 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +from huggingface_hub import file_exists from packaging import version from transformers import logging -from transformers.generation.configuration_utils import GenerationConfig, GenerationMode +from transformers.generation.configuration_utils import CompileConfig, GenerationConfig, GenerationMode from transformers.tokenization_utils import ExtensionsTrie from transformers.utils.generic import ModelOutput +from transformers.dynamic_module_utils import ( + check_python_requirements, + get_cached_module_file, + get_class_in_module, + resolve_trust_remote_code, +) import mindspore as ms import mindspore.numpy as mnp @@ -134,33 +142,33 @@ class GenerateDecoderOnlyOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. 
scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for + at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for + at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size, generated_length, hidden_size)`. past_key_values (`tuple(tuple(mindspore.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. """ - sequences: ms.Tensor = None - scores: Optional[Tuple[ms.Tensor]] = None - logits: Optional[Tuple[ms.Tensor]] = None - attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None - hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None + sequences: ms.Tensor + scores: Optional[tuple[ms.Tensor]] = None + logits: Optional[tuple[ms.Tensor]] = None + attentions: Optional[tuple[tuple[ms.Tensor]]] = None + hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None @dataclass @@ -174,45 +182,45 @@ class GenerateEncoderDecoderOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. 
Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for + at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for + at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`): - Tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`): - Tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of + tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. decoder_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. cross_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. decoder_hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size, generated_length, hidden_size)`. past_key_values (`tuple(tuple(mindspore.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
""" - sequences: ms.Tensor = None - scores: Optional[Tuple[ms.Tensor]] = None - logits: Optional[Tuple[ms.Tensor]] = None - encoder_attentions: Optional[Tuple[ms.Tensor]] = None - encoder_hidden_states: Optional[Tuple[ms.Tensor]] = None - decoder_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None - cross_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None + sequences: ms.Tensor + scores: Optional[tuple[ms.Tensor]] = None + logits: Optional[tuple[ms.Tensor]] = None + encoder_attentions: Optional[tuple[ms.Tensor]] = None + encoder_hidden_states: Optional[tuple[ms.Tensor]] = None + decoder_attentions: Optional[tuple[tuple[ms.Tensor]]] = None + cross_attentions: Optional[tuple[tuple[ms.Tensor]]] = None + decoder_hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None @dataclass @@ -229,20 +237,20 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), + tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for + at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. beam_indices (`ms.Tensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `ms.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`. attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. past_key_values (`tuple(tuple(ms.Tensor)))`, *optional*, returned when `use_cache=True`): Returns the model cache, used to speed up decoding. 
Different models have a different cache format, check @@ -251,12 +259,12 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): sequences: ms.Tensor = None sequences_scores: Optional[ms.Tensor] = None - scores: Optional[Tuple[ms.Tensor]] = None - logits: Optional[Tuple[ms.Tensor]] = None + scores: Optional[tuple[ms.Tensor]] = None + logits: Optional[tuple[ms.Tensor]] = None beam_indices: Optional[ms.Tensor] = None - attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None - hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None + attentions: Optional[tuple[tuple[ms.Tensor]]] = None + hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None @dataclass @@ -273,47 +281,47 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), + tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for + at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. beam_indices (`ms.Tensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `ms.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`. encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`): - Tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`): - Tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of + tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. decoder_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`. 
cross_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. decoder_hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `ms.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. past_key_values (`tuple(tuple(ms.Tensor)))`, *optional*, returned when `use_cache=True`): Returns the model cache, used to speed up decoding. Different models have a different cache format, check the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ - sequences: ms.Tensor = None + sequences: ms.Tensor sequences_scores: Optional[ms.Tensor] = None - scores: Optional[Tuple[ms.Tensor]] = None - logits: Optional[Tuple[ms.Tensor]] = None + scores: Optional[tuple[ms.Tensor]] = None + logits: Optional[tuple[ms.Tensor]] = None beam_indices: Optional[ms.Tensor] = None - encoder_attentions: Optional[Tuple[ms.Tensor]] = None - encoder_hidden_states: Optional[Tuple[ms.Tensor]] = None - decoder_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None - cross_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None - past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None + encoder_attentions: Optional[tuple[ms.Tensor]] = None + encoder_hidden_states: Optional[tuple[ms.Tensor]] = None + decoder_attentions: Optional[tuple[tuple[ms.Tensor]]] = None + cross_attentions: Optional[tuple[tuple[ms.Tensor]]] = None + decoder_hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None + past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None # Typing shortcuts @@ -353,18 +361,179 @@ class GenerationMixin: To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). """ + def load_custom_generate( + self, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + trust_remote_code: Optional[bool] = None, + **kwargs, + ) -> Callable: + """ + Loads and returns a custom generate function, given a model repo. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + trust_remote_code (`bool`, *optional*): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + **kwargs: + Additional keyword arguments for remote code loading. + + Raises: + OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory. + + Returns: + A callable that can be used to generate text. + """ + # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? 
If not -> OSError + is_local_code = os.path.exists(pretrained_model_name_or_path) + has_custom_generate_folder = True + if is_local_code: + if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")): + has_custom_generate_folder = False + else: + if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"): + has_custom_generate_folder = False + + if not has_custom_generate_folder: + raise OSError( + f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a " + "`generate.py` file, can't load the custom generate function." + ) + + # Handle opt-in `trust_remote_code` and related exceptions + error_message = ( + f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override " + "the default `generate` method." + ) + resolve_trust_remote_code( + trust_remote_code, + pretrained_model_name_or_path, + has_local_code=is_local_code, + has_remote_code=not is_local_code, + error_message=error_message, + ) + + # Load the custom generate function + check_python_requirements( + pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs + ) + module = get_cached_module_file( + pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs + ) + custom_generate_function = get_class_in_module("generate", module) + return custom_generate_function + + def _cache_dependant_input_preparation( + self, + input_ids: ms.Tensor, + inputs_embeds: Optional[ms.Tensor], + cache_position: Optional[ms.Tensor], + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + Generic cache-dependent input preparation + The code is put in a separate function to allow granular unit testing + as it needs a different implementation to be exportable. + + If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + - Exception 1: when passing input_embeds, input_ids may be missing entries + - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + generate the first token for each sequence. Later use the generated Input ids for continuation. + + The current implementation does not rely on ``self`` and could be + a class method. It is left as a standard method to be easily rewritten. + """ + # fixme there is no implementation for torch dynamo exporting + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( + inputs_embeds is not None # Exception 1 + or (cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + return inputs_embeds, input_ids + + def _cache_dependant_input_preparation_exporting( + self, + input_ids: ms.Tensor, + inputs_embeds: Optional[ms.Tensor], + cache_position: Optional[ms.Tensor], + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + This method implements method ``_cache_dependant_input_preparation`` + with :func:`torch.cond` to make it exportable with :func:`torch.export.export`. 
+ The code is put in a separate function to allow granular unit testing. + """ + if inputs_embeds is None: + input_ids = input_ids[:, cache_position] + else: + # This is the code we need to implemented with torch.cond. + # if input_ids.shape[1] == 0: + # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + # else: + # if cache_position[-1] >= input_ids.shape[1]: + # input_ids = input_ids[:, -cache_position.shape[0] :] + # else: + # if input_ids.shape[1] != cache_position.shape[0]: + # input_ids = input_ids[:, cache_position] + def branch_1(inputs_embeds, cache_position): + return inputs_embeds[:, -cache_position.shape[0] :] + + def branch_2(input_ids, cache_position): + return input_ids[:, -cache_position.shape[0] :] + + def branch_3(input_ids, cache_position): + return input_ids[:, cache_position] + + inputs_embeds, input_ids = mint.cond( + input_ids.shape[1] == 0, + ( + lambda input_ids, inputs_embeds, cache_position: ( + branch_1(inputs_embeds, cache_position), + input_ids, + ) + ), + ( + lambda input_ids, inputs_embeds, cache_position: ( + inputs_embeds, + mint.cond( + cache_position[-1] >= input_ids.shape[1], + branch_2, + lambda input_ids, cache_position: ( + mint.cond( + input_ids.shape[1] != cache_position.shape[0], + branch_3, + (lambda input_ids, cache_position: input_ids), + [input_ids, cache_position], + ) + ), + [input_ids, cache_position], + ), + ) + ), + [input_ids, inputs_embeds, cache_position], + ) + return inputs_embeds, input_ids def prepare_inputs_for_generation( self, input_ids, - past_key_values: Union[Cache, Tuple] = None, + past_key_values: Union[Cache, tuple] = None, attention_mask: Optional[ms.Tensor] = None, inputs_embeds: Optional[ms.Tensor] = None, cache_position: Optional[ms.Tensor] = None, **kwargs, ): """ - Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or + Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or slicing inputs given the existing cache. See the forward pass in the model documentation for expected arguments (different models might have different @@ -598,8 +767,8 @@ def _prepare_model_inputs( self, inputs: Optional[ms.Tensor] = None, bos_token_id: Optional[ms.Tensor] = None, - model_kwargs: Optional[Dict[str, ms.Tensor]] = None, - ) -> Tuple[ms.Tensor, Optional[str], Dict[str, ms.Tensor]]: + model_kwargs: Optional[dict[str, ms.Tensor]] = None, + ) -> tuple[ms.Tensor, Optional[str], dict[str, ms.Tensor]]: """ This function extracts the model-specific `inputs` for generation. """ @@ -662,7 +831,7 @@ def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[ms.Tensor] = None, bos_token_id: Optional[ms.Tensor] = None, - model_kwargs: Optional[Dict[str, ms.Tensor]] = None, + model_kwargs: Optional[dict[str, ms.Tensor]] = None, ) -> ms.Tensor: """Initializes input ids for generation, if necessary.""" if inputs is not None: @@ -694,7 +863,7 @@ def _prepare_attention_mask_for_generation( self, inputs_tensor: ms.Tensor, generation_config: GenerationConfig, - model_kwargs: Dict[str, Any], + model_kwargs: dict[str, Any], ) -> ms.Tensor: pad_token_id = generation_config._pad_token_tensor eos_token_id = generation_config._eos_token_tensor @@ -732,7 +901,7 @@ def _prepare_encoder_decoder_kwargs_for_generation( model_kwargs, model_input_name: Optional[str], generation_config: GenerationConfig, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: # 1. 
get encoder encoder = self.get_encoder() @@ -764,10 +933,10 @@ def _prepare_decoder_input_ids_for_generation( self, batch_size: int, model_input_name: str, - model_kwargs: Dict[str, ms.Tensor], + model_kwargs: dict[str, ms.Tensor], decoder_start_token_id: ms.Tensor, **ignore_kwargs, - ) -> Tuple[ms.Tensor, Dict[str, ms.Tensor]]: + ) -> tuple[ms.Tensor, dict[str, ms.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. @@ -821,7 +990,7 @@ def _expand_inputs_for_generation( is_encoder_decoder: bool = False, input_ids: Optional[ms.Tensor] = None, **model_kwargs, - ) -> Tuple[ms.Tensor, Dict[str, Any]]: + ) -> tuple[ms.Tensor, dict[str, Any]]: """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" if expand_size == 1: return input_ids, model_kwargs @@ -856,10 +1025,10 @@ def _expand_dict_for_generation(dict_to_expand): def _update_model_kwargs_for_generation( self, outputs: ModelOutput, - model_kwargs: Dict[str, Any], + model_kwargs: dict[str, Any], is_encoder_decoder: bool = False, num_new_tokens: int = 1, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: # update past_key_values keeping its naming used in model code for possible_cache_name in ALL_CACHE_NAMES: if possible_cache_name in outputs: @@ -955,7 +1124,7 @@ def _get_candidate_generator( logits_processor: LogitsProcessorList, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - model_kwargs: Dict, + model_kwargs: dict, ) -> CandidateGenerator: """ Returns the candidate generator to be used in `assisted_generation` @@ -1023,11 +1192,11 @@ def _get_candidate_generator( def _get_logits_processor( self, generation_config: GenerationConfig, - input_ids_seq_length: int, - encoder_input_ids: ms.Tensor, - prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], List[int]], - logits_processor: Optional[LogitsProcessorList], - model_kwargs: Optional[Dict[str, Any]] = None, + input_ids_seq_length: Optional[int] = None, + encoder_input_ids: ms.Tensor = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], list[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + model_kwargs: Optional[dict[str, Any]] = None, negative_prompt_ids: Optional[ms.Tensor] = None, negative_prompt_attention_mask: Optional[ms.Tensor] = None, ) -> LogitsProcessorList: @@ -1037,6 +1206,8 @@ def _get_logits_processor( """ # instantiate processors list processors = LogitsProcessorList() + if logits_processor is None: + logits_processor = [] if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: processors.append( @@ -1106,7 +1277,7 @@ def _get_logits_processor( ) if ( generation_config.min_length is not None - and generation_config._eos_token_tensor is not None + and getattr(generation_config._eos_token_tensor, "_eos_token_tensor", None) is not None and generation_config.min_length > 0 ): processors.append( @@ -1117,7 +1288,7 @@ def _get_logits_processor( ) if ( generation_config.min_new_tokens is not None - and generation_config._eos_token_tensor is not None + and getattr(generation_config._eos_token_tensor, "_eos_token_tensor", None) is not None and generation_config.min_new_tokens > 0 ): processors.append( @@ -1177,13 +1348,6 @@ def _get_logits_processor( ) ) - # Fixme - # if 
generation_config.forced_decoder_ids is not None: - # raise ValueError( - # "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument " - # "in favour of `input_ids` or `decoder_input_ids` respectively.", - # ) - processors = self._merge_criteria_processor_list(processors, logits_processor) # Processors previously known as `LogitsWarpers`, only applied with sampling strategies @@ -1232,7 +1396,11 @@ def _get_logits_processor( # Watermarking should be after all logits processing is finished (see #34630) if generation_config.watermarking_config is not None: - processors.append(generation_config.watermarking_config.construct_processor(self.config.vocab_size)) + processors.append( + generation_config.watermarking_config.construct_processor( + self.config.get_text_config().vocab_size + ) + ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -1288,7 +1456,7 @@ def _merge_criteria_processor_list( Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same processor/criteria is present on both lists, use the user-defined one. - (Note: up to v4.49.0, this funtion threw an exception is the same logit processor was found twice.) + (Note: up to v4.49.0, this function threw an exception is the same logit processor was found twice.) """ if len(custom_list) == 0: return default_list @@ -1328,16 +1496,16 @@ def compute_transition_scores( used). This is a convenient method to quicky obtain the scores of the selected tokens at generation time. Parameters: - sequences (`torch.LongTensor`): + sequences (`ms.Tensor`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)`): + scores (`tuple(ms.Tensor)`): Transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*): - Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + beam_indices (`ms.Tensor`, *optional*): + Beam indices of generated token id at each generation step. `ms.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at generate-time. normalize_logits (`bool`, *optional*, defaults to `False`): @@ -1411,7 +1579,7 @@ def compute_transition_scores( # 3. Optionally normalize the logits (across the vocab dimension) if normalize_logits: - scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1]) + scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1]) scores = mint.nn.functional.log_softmax(scores, dim=1) scores = scores.reshape(-1, scores.shape[-1]) @@ -1425,7 +1593,7 @@ def compute_transition_scores( beam_indices[beam_indices_mask] = 0 # 6. multiply beam_indices with vocab size to gather correctly from scores - beam_sequence_indices = beam_indices * self.config.vocab_size + beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size # 7. 
Define which indices contributed to scores cut_idx = sequences.shape[-1] - max_beam_length @@ -1492,15 +1660,8 @@ def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): f"to `generate()` {doc_reference}." ) - def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" - # If a `Cache` instance is passed, checks whether the model is compatible with it - if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: - raise ValueError( - f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " - "check the model documentation for supported cache formats." - ) - # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: @@ -1649,8 +1810,8 @@ def _prepare_generated_length( return generation_config def _prepare_generation_config( - self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict - ) -> Tuple[GenerationConfig, Dict]: + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict + ) -> tuple[GenerationConfig, dict]: """ Prepares the base generation config, then applies any generation configuration options from kwargs. This function handles retrocompatibility with respect to configuration files. @@ -1686,6 +1847,8 @@ def _prepare_generation_config( generation_config = self.generation_config using_model_generation_config = True + # `torch.export.export` usually raises an exception if it is called + # with ``strict=True``. deepcopy can only be processed if ``strict=False``. 
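+        # The deepcopy also keeps the caller's object and the model's stored `generation_config`
+        # unchanged when individual values are modified further below.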
generation_config = copy.deepcopy(generation_config) if not using_model_generation_config: @@ -1694,19 +1857,28 @@ def _prepare_generation_config( # - otherwise: legacy behavior, let's just make sure we have the tokens defined model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version) if use_model_defaults is True or ( - use_model_defaults is None and model_base_version >= version.parse("4.50.0") + use_model_defaults is None and model_base_version >= version.parse("4.50.0") ): modified_values = {} - default_generation_config = GenerationConfig() - for key, default_value in default_generation_config.__dict__.items(): + global_default_generation_config = GenerationConfig() + model_generation_config = self.generation_config + # we iterate over the model's generation config: it may hold custom keys, which we'll want to copy + for key, model_gen_config_value in model_generation_config.__dict__.items(): if key.startswith("_") or key == "transformers_version": # metadata continue - custom_gen_config_value = getattr(generation_config, key) - model_gen_config_value = getattr(self.generation_config, key) - if custom_gen_config_value == default_value and model_gen_config_value != default_value: + global_default_value = getattr(global_default_generation_config, key, None) + custom_gen_config_value = getattr(generation_config, key, None) + if ( + custom_gen_config_value == global_default_value + and model_gen_config_value != global_default_value + ): modified_values[key] = model_gen_config_value setattr(generation_config, key, model_gen_config_value) - if len(modified_values) > 0: + # edge case: we may set `temperature=0.0` and `do_sample=False`, but the model defaults to + # `do_sample=True` + if generation_config.temperature == 0.0: + generation_config.do_sample = False + if use_model_defaults is None and len(modified_values) > 0: logger.warning_once( f"`generation_config` default values have been modified to match model-specific defaults: " f"{modified_values}. If this is not desired, please set these values explicitly." @@ -1728,6 +1900,8 @@ def _prepare_generation_config( def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" + if "cache_position" in model_kwargs and model_kwargs["cache_position"]: + return model_kwargs # the lines below are equivalent to `mint.arange` [0,1,2,3, .., input_shape-1] if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: cache_position = mint.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=ms.int32).cumsum(0) - 1 @@ -1771,6 +1945,9 @@ def _get_cache( Returns the resulting cache object. 
""" + if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""): + cache_implementation = "hybrid_chunked" + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None @@ -1786,9 +1963,8 @@ def _get_cache( not hasattr(self, "_cache") or (not isinstance(cache_to_check, cache_cls)) or cache_to_check.max_batch_size != batch_size + or cache_to_check.max_cache_len < max_cache_len ) - if cache_implementation != "mamba": - need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len if requires_cross_attention_cache and hasattr(self, "_cache"): need_new_cache = ( @@ -1808,6 +1984,9 @@ def _get_cache( "max_cache_len": max_cache_len, "dtype": cache_dtype, } + if cache_implementation in ["static", "hybrid", "offloaded_static"]: + cache_kwargs.update({"tp_size": self.tp_size}) + self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() @@ -1816,7 +1995,7 @@ def _get_cache( else: self._cache.reset() return self._cache - + def _supports_default_dynamic_cache(self) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. @@ -1847,7 +2026,7 @@ def _supports_default_dynamic_input(self) -> bool: def _prepare_legacy_cache( self, generation_config: GenerationConfig, - model_kwargs: Dict, + model_kwargs: dict, cache_name: str, batch_size: int, ): @@ -1886,7 +2065,7 @@ def _prepare_legacy_cache( def _prepare_cache_for_generation( self, generation_config: GenerationConfig, - model_kwargs: Dict, + model_kwargs: dict, assistant_model: "PreTrainedModel", batch_size: int, max_cache_length: int, @@ -1895,8 +2074,10 @@ def _prepare_cache_for_generation( Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is instantiated, writes it to `model_kwargs`, under the name expected by the model. """ - + + is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"]) cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None ) @@ -1954,6 +2135,9 @@ def _prepare_cache_for_generation( ) generation_config.cache_implementation = None + generation_config.cache_implementation = generation_config.cache_implementation or getattr( + self.config.get_text_config(decoder=True), "cache_implementation", None + ) if generation_config.cache_implementation is not None: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation == "static" and not self._supports_static_cache: @@ -2134,19 +2318,48 @@ def _padding_inputs( return new_input_ids, new_inputs_embeds, new_labels, new_position_ids, new_attention_mask + def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: GenerationConfig) -> bool: + """ + Determines whether to trigger auto-compilation of the model's forward pass at generation time. 
+ """ + # Override: honor `disable_compile` flag + if generation_config.disable_compile: + return False + + # Base logic + valid_hardware = ms.get_context("mode")==0 or bool( + generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices + ) + using_compilable_cache = ( + isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable + ) + # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile) + can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph + + # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn + # them + if generation_config.compile_config is not None and not can_compile: + logger.warning_once( + "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation " + "will be skipped." + ) + + return can_compile + def generate( self, inputs: Optional[ms.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], List[int]]] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], list[int]]] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, negative_prompt_ids: Optional[ms.Tensor] = None, negative_prompt_attention_mask: Optional[ms.Tensor] = None, use_model_defaults: Optional[bool] = None, + custom_generate: Optional[str] = None, **kwargs, ) -> Union[tuple, ms.Tensor]: r""" @@ -2187,13 +2400,13 @@ def generate( generation config an error is thrown. If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is intended for advanced users. - prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`, *optional*): + prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], list[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity - Retrieval](https://arxiv.org/abs/2010.00904). + Retrieval](https://huggingface.co/papers/2010.00904). synced_gpus (`bool`, *optional*): Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished @@ -2216,7 +2429,12 @@ def generate( generation configuration (`model.generation_config`), as opposed to the global defaults (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be `True`. - kwargs (`Dict[str, Any]`, *optional*): + custom_generate (`str`, *optional*): + A string containing the name of a huggingface.co repository. 
If provided, the custom `generate` + function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the + standard `generate` method. Note that the logic is for generation is entirely defined in that + repository, and the return type may be different from the standard `generate` method. + kwargs (`dict[str, Any]`, *optional*): Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. @@ -2237,9 +2455,28 @@ def generate( - [`~generation.GenerateEncoderDecoderOutput`], - [`~generation.GenerateBeamEncoderDecoderOutput`] """ + # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead + trust_remote_code = kwargs.pop("trust_remote_code", None) + if custom_generate is not None: + # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: + # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to + # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`. + global_keys_to_exclude = { + "self", + "kwargs", + "global_keys_to_exclude", + "trust_remote_code", + "custom_generate", + } + generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude} + generate_arguments.update(kwargs) + + custom_generate_function = self.load_custom_generate( + custom_generate, trust_remote_code=trust_remote_code, **kwargs + ) + return custom_generate_function(model=self, **generate_arguments) # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - self._validate_model_class() tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation @@ -2321,7 +2558,7 @@ def generate( streamer.put(input_ids.asnumpy()) # 6. Prepare `max_length` depending on other stopping criteria. 
- input_ids_length = input_ids.shape[-1] + input_ids_length = input_ids.shape[1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config = self._prepare_generated_length( @@ -2624,6 +2861,31 @@ def _sample( unfinished_sequences = ops.ones(batch_size, dtype=ms.int32) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_forward = self.__call__ + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: + os.environ["TOKENIZERS_PARALLELISM"] = "0" + # If we use FA2 and a static cache, we cannot compile with fullgraph + if self.config._attn_implementation == "flash_attention_2" and getattr( + model_kwargs.get("past_key_values"), "is_compileable", False + ): + if generation_config.compile_config is None: + generation_config.compile_config = CompileConfig(fullgraph=False) + # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user) + elif generation_config.compile_config.fullgraph: + logger.warning_once( + "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as " + "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`." + ) + generation_config.compile_config.fullgraph = False + model_forward = self.get_compiled_call(generation_config.compile_config) + + if generation_config.prefill_chunk_size is not None: + model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) + is_prefill = False + else: + is_prefill = True + multinomial = get_multinomial_op() step = 0 s_time = time.time() @@ -2646,10 +2908,18 @@ def _sample( model_inputs.update({"output_hidden_states": output_hidden_states}) # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True, - ) + if is_prefill: + outputs = self( + **model_inputs, + return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True, + ) + is_prefill = False + else: + outputs = model_forward( + **model_inputs, + return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True, + ) + if not isinstance(outputs, ModelOutput): outputs = ModelOutput( loss=None, @@ -2877,7 +3147,7 @@ def _get_top_k_continuations( num_beams: int, vocab_size: int, batch_size: int, - ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]: + ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor]: """ Get top-K continuations given the accumulated log probs on the next token. A few notes to understand what's going on: @@ -2925,7 +3195,7 @@ def _get_running_beams_for_next_iteration( topk_running_beam_indices: ms.Tensor, next_token_hits_stopping_criteria: ms.Tensor, num_beams: int, - ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]: + ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor]: """ Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the best non-finished beams to continue beam search in the next iteration. 
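# A minimal sketch of the ranking rule that `_update_finished_beams` (below) applies when deciding
# whether a newly finished hypothesis should replace one of the stored finished beams: the
# accumulated log-probability is divided by `generated_length ** length_penalty`. The scalar values
# here are illustrative only; the real code applies the same formula to batched `ms.Tensor`s.
sum_logprobs = -4.2        # accumulated log-probability of a finished hypothesis
generated_length = 6       # number of tokens generated after the prompt
length_penalty = 1.0       # > 1.0 favours longer beams, < 1.0 favours shorter ones
beam_score = sum_logprobs / (generated_length**length_penalty)  # -0.7; compared against stored finished-beam scores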
@@ -2948,6 +3218,7 @@ def _update_finished_beams( topk_log_probs: ms.Tensor, beam_indices: ms.Tensor, topk_running_beam_indices: ms.Tensor, + is_early_stop_heuristic_unsatisfied: ms.Tensor, is_sent_finished: ms.Tensor, next_token_hits_stopping_criteria: ms.Tensor, top_num_beam_mask: ms.Tensor, @@ -2956,7 +3227,7 @@ def _update_finished_beams( decoder_prompt_len: int, length_penalty: float, early_stopping: Union[bool, str], - ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: + ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: """ Updates the finished beams if (and only if) there are new completed sequences that have a higher score than the current finished sequences. @@ -2974,6 +3245,9 @@ def _update_finished_beams( early_stopping is True, ms.int32 ) topk_log_probs += beams_in_batch_are_full.to(ms.float32) * -1.0e9 + # - make sure no scores can be added anymore if improvement is not possible + topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(ms.float32) * -1.0e9 + # - make sure still running sequences cannot be chosen as finalized beam topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9 @@ -3049,7 +3323,7 @@ def _beam_search( num_beams = generation_config.num_beams num_return_sequences = generation_config.num_return_sequences - batch_size_unflattened, cur_len = input_ids.shape + batch_size_unflattened, cur_len = input_ids.shape[:2] batch_size = batch_size_unflattened // num_beams # TODO (joao): standardize special cases if self.__class__.__name__ == "MoshiDepthDecoder": @@ -3120,6 +3394,9 @@ def _beam_search( # per batch, beam-item state bit indicating if sentence has finished. is_sent_finished = mint.zeros((batch_size, num_beams), dtype=ms.bool_) + # per batch state bit indicating if there is a possibility to improve the best finished sentence. + is_early_stop_heuristic_unsatisfied = mint.ones((batch_size, 1), dtype=ms.bool_) + # per batch, beam-item state bit indicating if there are valid continuations. 
next_token_hits_stopping_criteria = mint.zeros((batch_size, num_beams), dtype=ms.bool_) @@ -3227,6 +3504,7 @@ def _beam_search( topk_log_probs=topk_log_probs, beam_indices=beam_indices, topk_running_beam_indices=topk_running_beam_indices, + is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied, is_sent_finished=is_sent_finished, next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, top_num_beam_mask=top_num_beam_mask, @@ -3301,3 +3579,43 @@ def _beam_search( ) else: return sequences + + def _prefill_chunking(self, input_ids: ms.Tensor, generation_config: GenerationConfig, **model_kwargs): + chunk_size = generation_config.prefill_chunk_size + # Only chunk up the token just before last, so that decoding is completely performed outside this function + # (here we simply prefill the cache) + input_chunks = mint.split(input_ids[:, :-1], chunk_size, dim=-1) + + if "past_key_values" not in model_kwargs: + raise ValueError("Cannot use prefill chunking without a cache") + + model_forward = self.construct + + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: + model_forward = self.get_compiled_call(generation_config.compile_config) + + attention_mask = model_kwargs.pop("attention_mask", None) + + past_length = 0 + for input_chunk in input_chunks: + current_length = past_length + input_chunk.shape[-1] + # Prepare inputs + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask[:, :current_length] + model_kwargs["cache_position"] = mint.arange( + past_length, current_length, dtype=ms.int64 + ) + model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0) + model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs) + + outputs = model_forward(**model_inputs, return_dict=True) + + model_kwargs["past_key_values"] = outputs.past_key_values + past_length = current_length + + model_kwargs["attention_mask"] = attention_mask + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + _ = model_kwargs.pop("position_ids", None) + + return model_kwargs \ No newline at end of file diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 01de7a138a..1252dd8ff1 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -247,7 +247,7 @@ def dtype_byte_size(dtype): def shard_checkpoint( - state_dict: Dict[str, Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME + state_dict: dict[str, Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME ): """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -266,7 +266,7 @@ def shard_checkpoint( Args: - state_dict (`Dict[str, Tensor]`): The state dictionary of a model to save. + state_dict (`dict[str, Tensor]`): The state dictionary of a model to save. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). 
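# A minimal sketch of the greedy sharding rule described in `shard_checkpoint`'s docstring above:
# weights are appended to the current shard until adding one more would push it past
# `max_shard_size`, at which point a new shard is started (so a single weight larger than the limit
# still ends up in its own shard). The helper name and byte counts are illustrative only; the real
# function derives each size from the tensor's dtype and shape.
def _greedy_shard(named_sizes, max_shard_size):
    shards, current, current_size = [], {}, 0
    for name, nbytes in named_sizes:
        if current and current_size + nbytes > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = nbytes
        current_size += nbytes
    if current:
        shards.append(current)
    return shards

# _greedy_shard([("a", 6), ("b", 5), ("c", 4)], max_shard_size=10) -> [{"a": 6}, {"b": 5, "c": 4}]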
@@ -445,10 +445,10 @@ def _get_mindspore_dtype( def _find_missing_and_unexpected_keys( cls, model: "PreTrainedModel", - original_checkpoint_keys: List[str], - checkpoint_keys: List[str], + original_checkpoint_keys: list[str], + checkpoint_keys: list[str], loading_base_model_from_task_state_dict: bool, -) -> Tuple[List[str], List[str]]: +) -> tuple[list[str], list[str]]: """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model parameters) """ @@ -566,7 +566,7 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask): return extended_attention_mask def get_extended_attention_mask( - self, attention_mask: Tensor, input_shape: Tuple[int], dtype: ms.float32 = None + self, attention_mask: Tensor, input_shape: tuple[int], dtype: ms.float32 = None ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. @@ -574,7 +574,7 @@ def get_extended_attention_mask( Arguments: attention_mask (`Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (`Tuple[int]`): + input_shape (`tuple[int]`): The shape of the input to the model. Returns: @@ -683,7 +683,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(total_numel) - def estimate_tokens(self, input_dict: Dict[str, Union[ms.Tensor, Any]]) -> int: + def estimate_tokens(self, input_dict: dict[str, Union[ms.Tensor, Any]]) -> int: """ Helper function to estimate the total number of tokens from the model inputs. @@ -704,7 +704,7 @@ def estimate_tokens(self, input_dict: Dict[str, Union[ms.Tensor, Any]]) -> int: self.warnings_issued["estimate_tokens"] = True return 0 - def floating_point_ops(self, input_dict: Dict[str, Union[ms.Tensor, Any]], exclude_embeddings: bool = True) -> int: + def floating_point_ops(self, input_dict: dict[str, Union[ms.Tensor, Any]], exclude_embeddings: bool = True) -> int: """ Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a batch with this transformer model. Default approximation neglects the quadratic dependency on the number of @@ -883,6 +883,8 @@ class PreTrainedModel(nn.Cell, EmbeddingAccessMixin, ModuleUtilsMixin, Generatio # SDPA support _supports_sdpa = False + _can_compile_fullgraph = False + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? _supports_cache_class = False _supports_static_cache = False @@ -942,9 +944,9 @@ class LlamaModel(PreTrainedModel): return self._can_record_outputs or {} @property - def dummy_inputs(self) -> Dict[str, Tensor]: + def dummy_inputs(self) -> dict[str, Tensor]: """ - `Dict[str, Tensor]`: Dummy inputs to do a forward pass in the network. + `dict[str, Tensor]`: Dummy inputs to do a forward pass in the network. """ return {"input_ids": Tensor(DUMMY_INPUTS)} @@ -1532,8 +1534,8 @@ def tie_weights(self): def _tie_encoder_decoder_weights( encoder: nn.Cell, decoder: nn.Cell, base_model_prefix: str, base_encoder_name: str ): - uninitialized_encoder_weights: List[str] = [] - tied_weights: List[str] = [] + uninitialized_encoder_weights: list[str] = [] + tied_weights: list[str] = [] if decoder.__class__ != encoder.__class__: logger.info( f"{decoder.__class__} and {encoder.__class__} are not equal. 
In this case make sure that all encoder" @@ -1545,7 +1547,7 @@ def tie_encoder_to_decoder_recursively( encoder_pointer: nn.Cell, module_name: str, base_encoder_name: str, - uninitialized_encoder_weights: List[str], + uninitialized_encoder_weights: list[str], depth=0, total_decoder_name="", total_encoder_name="", @@ -1910,7 +1912,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" ) - def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]: raise NotImplementedError( f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" @@ -1993,7 +1995,7 @@ def save_pretrained( For backward compatibility with PEFT library, in case adapter weights are attached to the model, all keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can disable this behaviours by setting `save_peft_format` to `False`. - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ use_auth_token = kwargs.pop("use_auth_token", None) @@ -2224,7 +2226,7 @@ def from_pretrained( save directory. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. - state_dict (`Dict[str, Tensor]`, *optional*): + state_dict (`dict[str, Tensor]`, *optional*): A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own @@ -2249,7 +2251,7 @@ def from_pretrained( resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. - proxies (`Dict[str, str]`, *optional*): + proxies (`dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): @@ -2869,7 +2871,7 @@ def from_pretrained( return model @staticmethod - def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]: + def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]: """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert) # This rename is logged. 
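# A minimal sketch of the legacy renaming handled by `_fix_state_dict_key_on_load` above; the key
# and the exact matching rule are illustrative (the real implementation may cover more patterns):
def _rename_legacy_key(key: str) -> tuple[str, bool]:
    new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight").replace("LayerNorm.beta", "LayerNorm.bias")
    return new_key, new_key != key

# _rename_legacy_key("bert.embeddings.LayerNorm.gamma") -> ("bert.embeddings.LayerNorm.weight", True)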
@@ -2882,8 +2884,8 @@ def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]: def _get_key_renaming_mapping( self, - checkpoint_keys: List[str], - key_mapping: Optional[Dict[str, str]] = None, + checkpoint_keys: list[str], + key_mapping: Optional[dict[str, str]] = None, loading_base_model_from_task_state_dict: bool = False, loading_task_model_from_base_state_dict: bool = False, ): @@ -2957,7 +2959,7 @@ def _load_pretrained_model( sharded_metadata=None, dtype=None, keep_in_fp32_modules=None, - key_mapping: Optional[Dict[str, str]] = None, + key_mapping: Optional[dict[str, str]] = None, weights_only: bool = True, ): model_state_dict = {k: v for k, v in model.parameters_and_names()} @@ -3181,10 +3183,24 @@ def _find_mismatched_keys( return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + def get_compiled_call(self) -> Callable: + """Return a `mindspore.jit`'d version of `self.__call__`. This is useful to dynamically choose between + non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't + want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding + (where we want the speed-ups of compiled version with static shapes).""" + # Only reset it if not present or different from previous config + if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT + return self.__call__ + if ( + not hasattr(self, "_compiled_call") + ): + self._compiled_call = ms.jit(self.__call__) + return self._compiled_call + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): module_keys = {".".join(key.split(".")[:-1]) for key in names} - # torch.nn.ParameterList is a special case where two parameter keywords + # torch.nn.Parameterlist is a special case where two parameter keywords # are appended to the module name, *e.g.* bert.special_embeddings.0 module_keys = module_keys.union( {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()} @@ -3496,7 +3512,7 @@ def construct( is_impossible: Optional[ms.Tensor] = None, p_mask: Optional[ms.Tensor] = None, return_dict: bool = False, - ) -> Union[SquadHeadOutput, Tuple[ms.Tensor]]: + ) -> Union[SquadHeadOutput, tuple[ms.Tensor]]: """ Args: hidden_states (`mindspore.Tensor` of shape `(batch_size, seq_len, hidden_size)`): @@ -3722,7 +3738,7 @@ def __len__(self): def register(cls, key: str, value: Callable): cls._global_mapping.update({key: value}) - def valid_keys(self) -> List[str]: + def valid_keys(self) -> list[str]: return list(self.keys()) From 17da5c731c37804b6c0fb61ed83347ed6d211118 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 11:18:30 +0800 Subject: [PATCH 07/94] feat(transformers): add ernie4.5 for validation --- mindone/transformers/__init__.py | 7 +- mindone/transformers/models/__init__.py | 1 + .../transformers/models/aria/modeling_aria.py | 2 +- .../models/cohere2/modeling_cohere2.py | 2 +- .../transformers/models/ernie4_5/__init__.py | 1 + .../models/ernie4_5/modeling_ernie4_5.py | 468 ++++++++++++++++++ .../models/gemma/modeling_gemma.py | 2 +- .../transformers/models/glm/modeling_glm.py | 3 +- .../transformers/models/glm4/modeling_glm4.py | 2 +- .../models/glm4v/modeling_glm4v.py | 3 +- .../models/granite/modeling_granite.py | 2 +- .../models/helium/modeling_helium.py | 2 +- .../models/llama/modeling_llama.py | 3 +- .../models/mixtral/modeling_mixtral.py | 2 +- 
.../transformers/models/phi3/modeling_phi3.py | 3 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen3/modeling_qwen3.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- 18 files changed, 494 insertions(+), 15 deletions(-) create mode 100644 mindone/transformers/models/ernie4_5/__init__.py create mode 100644 mindone/transformers/models/ernie4_5/modeling_ernie4_5.py diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 546edbc7f6..70a06c1910 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -21,7 +21,7 @@ # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names # in the namespace without actually importing anything (and especially none of the backends). -__version__ = "4.54.0" +__version__ = "4.54.1" import transformers from packaging import version @@ -274,6 +274,11 @@ DebertaV2PreTrainedModel, ) from .models.dpt import DPTForDepthEstimation +from .models.ernie4_5 import ( + Ernie4_5PreTrainedModel, + Ernie4_5Model, + Ernie4_5ForCausalLM +) from .models.fuyu import FuyuForCausalLM, FuyuPreTrainedModel from .models.gemma import ( GemmaForCausalLM, diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 2fd7458cff..9ec2a7fc48 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -35,6 +35,7 @@ clip, convbert, dpt, + ernie4_5, fuyu, gemma, gemma2, diff --git a/mindone/transformers/models/aria/modeling_aria.py b/mindone/transformers/models/aria/modeling_aria.py index 4e9a5f5c25..1f0229f8ea 100644 --- a/mindone/transformers/models/aria/modeling_aria.py +++ b/mindone/transformers/models/aria/modeling_aria.py @@ -26,12 +26,12 @@ from transformers import AriaConfig, AriaTextConfig from transformers.utils import ( - LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs import mindspore as ms import mindspore.mint as mint diff --git a/mindone/transformers/models/cohere2/modeling_cohere2.py b/mindone/transformers/models/cohere2/modeling_cohere2.py index c63a6cdd57..8c7486ccde 100644 --- a/mindone/transformers/models/cohere2/modeling_cohere2.py +++ b/mindone/transformers/models/cohere2/modeling_cohere2.py @@ -26,12 +26,12 @@ from transformers import Cohere2Config from transformers.utils import ( - LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs from transformers.utils.deprecation import deprecate_kwarg import mindspore as ms diff --git a/mindone/transformers/models/ernie4_5/__init__.py b/mindone/transformers/models/ernie4_5/__init__.py new file mode 100644 index 0000000000..ba97e39a99 --- /dev/null +++ b/mindone/transformers/models/ernie4_5/__init__.py @@ -0,0 +1 @@ +from .modeling_ernie4_5 import Ernie4_5PreTrainedModel, Ernie4_5Model, Ernie4_5ForCausalLM \ No newline at end of file diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py new file mode 100644 index 0000000000..9efebed5ce --- /dev/null +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -0,0 +1,468 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ernie4_5/modular_ernie4_5.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ernie4_5.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import mindspore as ms +from mindspore import mint, nn, Parameter + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs +from transformers.utils import auto_docstring, can_return_tuple +from transformers.models.ernie4_5.configuration_ernie4_5 import Ernie4_5Config + + +class Ernie4_5RotaryEmbedding(nn.Cell): + def __init__(self, config: Ernie4_5Config): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @ms._no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].float() + + # fixme there is not implementation for torch.autocast + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).swapaxes(1, 2) + emb = mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + # keeping it in full precision + return cos, sin + + +class Ernie4_5MLP(nn.Cell): + def __init__(self, config: Ernie4_5Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def construct(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return mint.stack((-x2, x1), dim=-1).flatten(-2) + + +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.swapaxes(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.swapaxes(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # glm rope style (with full dim) and full precision + original_dtype = q.dtype + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q.float() * cos) + (rotate_half(q).float() * sin) + k_embed = (k.float() * cos) + (rotate_half(k).float() * sin) + + return q_embed.to(original_dtype), k_embed.to(original_dtype) + + +class Ernie4_5Attention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Ernie4_5Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.attention_dropout = 0.0 + self.is_causal = True + + self.q_proj = mint.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = mint.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = mint.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = mint.nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + attention_mask: Optional[ms.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ms.Tensor, ms.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = 
ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Ernie4_5RMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + Ernie4_5RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = Parameter(mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Ernie4_5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Ernie4_5Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx) + + self.mlp = Ernie4_5MLP(config) + self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ms.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Ernie4_5PreTrainedModel(PreTrainedModel): + config: Ernie4_5Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Ernie4_5DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Ernie4_5DecoderLayer, + "attentions": Ernie4_5Attention, + } + + +@auto_docstring +class Ernie4_5Model(Ernie4_5PreTrainedModel): + def __init__(self, config: Ernie4_5Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.CellList( + [Ernie4_5DecoderLayer(config, layer_idx) for 
layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + cache_position: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: ms.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: ms.Tensor = mint.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Ernie4_5Model(config) + self.vocab_size = config.vocab_size + self.lm_head = mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + logits_to_keep: Union[int, ms.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Ernie4_5ForCausalLM", "Ernie4_5Model", "Ernie4_5PreTrainedModel"] diff --git a/mindone/transformers/models/gemma/modeling_gemma.py b/mindone/transformers/models/gemma/modeling_gemma.py index ce987857fa..7c600a494e 100644 --- a/mindone/transformers/models/gemma/modeling_gemma.py +++ b/mindone/transformers/models/gemma/modeling_gemma.py @@ -25,7 +25,7 @@ from typing import Callable, List, Optional, Tuple, Union from transformers.models.gemma.configuration_gemma import GemmaConfig -from transformers.utils import LossKwargs, logging +from ...utils import LossKwargs, logging import mindspore as ms from mindspore import mint, nn, ops diff --git a/mindone/transformers/models/glm/modeling_glm.py b/mindone/transformers/models/glm/modeling_glm.py index 1bc9200fbb..a34009917d 100644 --- a/mindone/transformers/models/glm/modeling_glm.py +++ b/mindone/transformers/models/glm/modeling_glm.py @@ -26,7 +26,8 @@ from typing import Callable, Optional, Tuple, Union from transformers import GlmConfig -from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops diff --git a/mindone/transformers/models/glm4/modeling_glm4.py b/mindone/transformers/models/glm4/modeling_glm4.py index 455c15357a..6e8dcf85b4 100644 --- a/mindone/transformers/models/glm4/modeling_glm4.py +++ b/mindone/transformers/models/glm4/modeling_glm4.py @@ -25,7 +25,7 @@ from typing import Callable, Optional, Tuple, Union from transformers.models.glm4.configuration_glm4 import Glm4Config -from transformers.utils import LossKwargs +from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops diff --git a/mindone/transformers/models/glm4v/modeling_glm4v.py b/mindone/transformers/models/glm4v/modeling_glm4v.py index 773fff8b2b..16dad90618 100644 --- a/mindone/transformers/models/glm4v/modeling_glm4v.py +++ b/mindone/transformers/models/glm4v/modeling_glm4v.py @@ -26,7 +26,8 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from transformers.utils import LossKwargs, logging +from transformers.utils import logging +from ...utils import LossKwargs import mindspore as ms import mindspore.mint.nn.functional as F diff --git 
a/mindone/transformers/models/granite/modeling_granite.py b/mindone/transformers/models/granite/modeling_granite.py index dec2914e8c..2eee7a918d 100644 --- a/mindone/transformers/models/granite/modeling_granite.py +++ b/mindone/transformers/models/granite/modeling_granite.py @@ -26,12 +26,12 @@ from transformers.models.granite.configuration_granite import GraniteConfig from transformers.utils import ( - LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops diff --git a/mindone/transformers/models/helium/modeling_helium.py b/mindone/transformers/models/helium/modeling_helium.py index 9841d6fee2..46ce52d98e 100644 --- a/mindone/transformers/models/helium/modeling_helium.py +++ b/mindone/transformers/models/helium/modeling_helium.py @@ -3,13 +3,13 @@ from transformers import HeliumConfig from transformers.utils import ( - LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs from transformers.utils.deprecation import deprecate_kwarg import mindspore as ms diff --git a/mindone/transformers/models/llama/modeling_llama.py b/mindone/transformers/models/llama/modeling_llama.py index e6d888a983..1e9fe75258 100644 --- a/mindone/transformers/models/llama/modeling_llama.py +++ b/mindone/transformers/models/llama/modeling_llama.py @@ -24,7 +24,8 @@ import numpy as np from transformers import LlamaConfig -from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import LossKwargs import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops diff --git a/mindone/transformers/models/mixtral/modeling_mixtral.py b/mindone/transformers/models/mixtral/modeling_mixtral.py index a7fdbc461d..dfab7eec42 100644 --- a/mindone/transformers/models/mixtral/modeling_mixtral.py +++ b/mindone/transformers/models/mixtral/modeling_mixtral.py @@ -31,13 +31,13 @@ from transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.utils import ( - LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops diff --git a/mindone/transformers/models/phi3/modeling_phi3.py b/mindone/transformers/models/phi3/modeling_phi3.py index 6ceaa1a281..2a0e46cf22 100644 --- a/mindone/transformers/models/phi3/modeling_phi3.py +++ b/mindone/transformers/models/phi3/modeling_phi3.py @@ -27,7 +27,8 @@ from typing import Callable, List, Optional, Tuple, Union from transformers.models.phi3.configuration_phi3 import Phi3Config -from transformers.utils import LossKwargs, logging +from transformers.utils import logging +from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py index a67922b558..24f9a2cda6 100644 --- a/mindone/transformers/models/qwen2/modeling_qwen2.py +++ b/mindone/transformers/models/qwen2/modeling_qwen2.py @@ -15,7 +15,7 @@ from typing import Callable, List, Optional, Tuple, Union from transformers import Qwen2Config, logging -from 
transformers.utils import LossKwargs +from ...utils import LossKwargs import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops diff --git a/mindone/transformers/models/qwen3/modeling_qwen3.py b/mindone/transformers/models/qwen3/modeling_qwen3.py index 0589fe0d75..590bed7ec9 100644 --- a/mindone/transformers/models/qwen3/modeling_qwen3.py +++ b/mindone/transformers/models/qwen3/modeling_qwen3.py @@ -29,13 +29,13 @@ import numpy as np from transformers.models.qwen3.configuration_qwen3 import Qwen3Config from transformers.utils import ( - LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs import mindspore as ms from mindspore import Tensor, mint, nn diff --git a/mindone/transformers/models/starcoder2/modeling_starcoder2.py b/mindone/transformers/models/starcoder2/modeling_starcoder2.py index 0884d5d07d..f9a753f7b1 100644 --- a/mindone/transformers/models/starcoder2/modeling_starcoder2.py +++ b/mindone/transformers/models/starcoder2/modeling_starcoder2.py @@ -31,13 +31,13 @@ from transformers.models.starcoder2.configuration_starcoder2 import Starcoder2Config from transformers.utils import ( # can_return_tuple, - LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) +from ...utils import LossKwargs from transformers.utils.deprecation import deprecate_kwarg import mindspore as ms From fd769b1ca66bbb25832b450a8a58d378ff74c458 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 14:34:22 +0800 Subject: [PATCH 08/94] fix get_type_hints problem --- mindone/transformers/modeling_utils.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 1252dd8ff1..4fcb0832e2 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -959,27 +959,8 @@ def framework(self) -> str: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - # For BC we keep the original `config_class` definition in case - # there is a `config_class` attribute (e.g. remote code models), - # otherwise we derive it from the annotated `config` attribute. 
- - # defined in this particular subclass - child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None) - child_attribute = cls.__dict__.get("config_class", None) - - # defined in the class (this subclass or any parent class) - full_annotation = get_type_hints(cls).get("config", None) - full_attribute = cls.config_class - - # priority (child class_config -> child annotation -> global class_config -> global annotation) - if child_attribute is not None: - cls.config_class = child_attribute - elif child_annotation is not None: - cls.config_class = child_annotation - elif full_attribute is not None: - cls.config_class = full_attribute - elif full_annotation is not None: - cls.config_class = full_annotation + # fixme get_type_hints would encounter a problem during typing check + pass def __init__(self, config: PretrainedConfig, *inputs, **kwargs): super().__init__() From 37fe594d87239233bb8b0991597a5783d0454385 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 14:50:26 +0800 Subject: [PATCH 09/94] fix get_type_hints problem --- mindone/transformers/models/ernie4_5/modeling_ernie4_5.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index 9efebed5ce..4fc3be0e2b 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -299,7 +299,8 @@ def construct( @auto_docstring class Ernie4_5PreTrainedModel(PreTrainedModel): - config: Ernie4_5Config + # fixme check with PretrainedModel.__init__subclass__ + config = Ernie4_5Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Ernie4_5DecoderLayer"] From 0874e773a97896e4e5f1ea97377fca79a49dccb8 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 14:59:35 +0800 Subject: [PATCH 10/94] fix get_type_hints problem --- mindone/transformers/models/ernie4_5/modeling_ernie4_5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index 4fc3be0e2b..b2cb6b799f 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -300,7 +300,7 @@ def construct( @auto_docstring class Ernie4_5PreTrainedModel(PreTrainedModel): # fixme check with PretrainedModel.__init__subclass__ - config = Ernie4_5Config + config_class = Ernie4_5Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Ernie4_5DecoderLayer"] From 94fb78b7208c4433b965fe36eaed2e10ed7f62e7 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 16:46:19 +0800 Subject: [PATCH 11/94] fix metadata.get keyerror --- mindone/transformers/modeling_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 4fcb0832e2..80b7cddf0f 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -2719,7 +2719,9 @@ def from_pretrained( with safe_open(resolved_archive_file, framework="np") as f: metadata = f.metadata() - if metadata.get("format") in ("np", "pt"): + if metadata is None: + pass + elif 
metadata.get("format") in ("np", "pt"): pass elif metadata.get("format") == "tf": from_tf = True From bf69ef959cde86fd55efabbf5cbc29ab1e1fd522 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 16:52:32 +0800 Subject: [PATCH 12/94] fix masking_utils alignment --- mindone/transformers/models/ernie4_5/modeling_ernie4_5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index b2cb6b799f..3965b0f54a 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -365,13 +365,13 @@ def construct( if position_ids is None: position_ids = cache_position.unsqueeze(0) + # fixme to add position_ids, masking_utils should be aligned with transformers causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - position_ids=position_ids, ) hidden_states = inputs_embeds From a1be89c816e7ee97ed71deb9b405e5cb25b6dd23 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 16:55:40 +0800 Subject: [PATCH 13/94] fix generation/utils logic --- mindone/transformers/models/ernie4_5/modeling_ernie4_5.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index 3965b0f54a..defcfb44a2 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -311,6 +311,9 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False _supports_attention_backend = True + + # fixme generation uitls generate cache logic should be considered again + _supports_dynamic_input = True _can_record_outputs = { "hidden_states": Ernie4_5DecoderLayer, "attentions": Ernie4_5Attention, From b3334ac21ed07aab071bbbb8f0891869ca958f44 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 17:54:19 +0800 Subject: [PATCH 14/94] fix get_output_embedding override bug --- mindone/transformers/modeling_utils.py | 35 ------------------- .../models/ernie4_5/modeling_ernie4_5.py | 2 +- 2 files changed, 1 insertion(+), 36 deletions(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 80b7cddf0f..1c6b3d4e83 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -1182,41 +1182,6 @@ def _from_config(cls, config, **kwargs): return model - def get_input_embeddings(self) -> nn.Cell: - """ - Returns the model's input embeddings. - - Returns: - `nn.Cell`: A mindspore cell mapping vocabulary to hidden states. - """ - base_model = getattr(self, self.base_model_prefix, self) - if base_model is not self: - return base_model.get_input_embeddings() - else: - raise NotImplementedError - - def set_input_embeddings(self, value: nn.Cell): - """ - Set model's input embeddings. - - Args: - value (`nn.Cell`): A cell mapping vocabulary to hidden states. 
- """ - base_model = getattr(self, self.base_model_prefix, self) - if base_model is not self: - base_model.set_input_embeddings(value) - else: - raise NotImplementedError - - def get_output_embeddings(self) -> nn.Cell: - """ - Returns the model's output embeddings. - - Returns: - `nn.Cell`: A mindspore cell mapping hidden states to vocabulary. - """ - return None # Overwrite for models with output embeddings - def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: """ Check the availability of Flash Attention 2 for a given model. diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index defcfb44a2..844868f4fc 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -302,7 +302,7 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): # fixme check with PretrainedModel.__init__subclass__ config_class = Ernie4_5Config base_model_prefix = "model" - supports_gradient_checkpointing = True + supports_gradient_checkpointing = False _no_split_modules = ["Ernie4_5DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True From 833419cfecbeab561d382d715805b0af98359cae Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Thu, 21 Aug 2025 20:20:27 +0800 Subject: [PATCH 15/94] fix __init_subclass__ bug --- mindone/transformers/modeling_utils.py | 23 +++++++++++++++++-- .../models/ernie4_5/modeling_ernie4_5.py | 3 +-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 1c6b3d4e83..cf670cdecc 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -959,8 +959,27 @@ def framework(self) -> str: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - # fixme get_type_hints would encounter a problem during typing check - pass + # For BC we keep the original `config_class` definition in case + # there is a `config_class` attribute (e.g. remote code models), + # otherwise we derive it from the annotated `config` attribute. 
+ + # defined in this particular subclass + child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None) + child_attribute = cls.__dict__.get("config_class", None) + + # defined in the class (this subclass or any parent class) + full_annotation = cls.__dict__.get("config", None) + full_attribute = cls.config_class + + # priority (child class_config -> child annotation -> global class_config -> global annotation) + if child_attribute is not None: + cls.config_class = child_attribute + elif child_annotation is not None: + cls.config_class = child_annotation + elif full_attribute is not None: + cls.config_class = full_attribute + elif full_annotation is not None: + cls.config_class = full_annotation def __init__(self, config: PretrainedConfig, *inputs, **kwargs): super().__init__() diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index 844868f4fc..0430438995 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -299,8 +299,7 @@ def construct( @auto_docstring class Ernie4_5PreTrainedModel(PreTrainedModel): - # fixme check with PretrainedModel.__init__subclass__ - config_class = Ernie4_5Config + config : Ernie4_5Config base_model_prefix = "model" supports_gradient_checkpointing = False _no_split_modules = ["Ernie4_5DecoderLayer"] From c532a9a5da62251111d0af4aceb832647124a604 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:05:53 +0800 Subject: [PATCH 16/94] suplement checkpoint_conversion_mapping --- mindone/transformers/modeling_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index cf670cdecc..9786b79cdb 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -2347,6 +2347,13 @@ def from_pretrained( adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_name = kwargs.pop("adapter_name", "default") + key_mapping = kwargs.pop("key_mapping", None) + # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model + if key_mapping is None and any( + allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS + ): + key_mapping = cls._checkpoint_conversion_mapping + if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
Please use `token` instead.", @@ -2786,6 +2793,7 @@ def from_pretrained( sharded_metadata=sharded_metadata, dtype=mindspore_dtype, keep_in_fp32_modules=keep_in_fp32_modules, + key_mapping=key_mapping, ) if _adapter_model_path is not None: From 1ac2f7221919e2ef2d3063cfab3e83c547c37dcb Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:19:39 +0800 Subject: [PATCH 17/94] feat(transformers): upgrade beam search to v4.54 --- .../transformers/generation/beam_search.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mindone/transformers/generation/beam_search.py b/mindone/transformers/generation/beam_search.py index 0cb3aad2ea..01a2c02932 100644 --- a/mindone/transformers/generation/beam_search.py +++ b/mindone/transformers/generation/beam_search.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from collections import UserDict -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np from transformers.generation.beam_constraints import Constraint, ConstraintListState @@ -44,7 +44,7 @@ Beam indices indicating to which beam hypothesis the `next_tokens` correspond. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): + eos_token_id (`Union[int, list[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. beam_indices (`ms.Tensor`, *optional*): Beam indices indicating to which beam hypothesis each token correspond. @@ -80,7 +80,7 @@ The beam indices indicating to which beam the `final_beam_tokens` shall be added. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): + eos_token_id (`Union[int, list[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. Return: @@ -106,7 +106,7 @@ def process( next_tokens: ms.Tensor, next_indices: ms.Tensor, **kwargs, - ) -> Tuple[ms.Tensor]: + ) -> tuple[ms.Tensor]: raise NotImplementedError("This is an abstract method.") @abstractmethod @@ -154,7 +154,7 @@ class BeamSearchScorer(BeamScorer): [`~transformers.BeamSearchScorer.finalize`]. num_beam_groups (`int`, *optional*, defaults to 1): Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. - See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + See [this paper](https://huggingface.co/papers/1610.02424.pdf) for more details. max_length (`int`, *optional*): The maximum length of the sequence to be generated. 
""" @@ -215,11 +215,11 @@ def process( next_tokens: ms.Tensor, next_indices: ms.Tensor, pad_token_id: Optional[Union[int, ms.Tensor]] = None, - eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None, + eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None, beam_indices: Optional[ms.Tensor] = None, group_index: Optional[int] = 0, decoder_prompt_len: Optional[int] = 0, - ) -> Dict[str, ms.Tensor]: + ) -> dict[str, ms.Tensor]: # add up to the length which the next_scores is calculated on (including decoder prompt) cur_len = input_ids.shape[-1] + 1 batch_size = len(self._beam_hyps) // self.num_beam_groups @@ -320,10 +320,10 @@ def finalize( final_beam_indices: ms.Tensor, max_length: int, pad_token_id: Optional[Union[int, ms.Tensor]] = None, - eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None, + eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None, beam_indices: Optional[ms.Tensor] = None, decoder_prompt_len: Optional[int] = 0, - ) -> Tuple[ms.Tensor]: + ) -> tuple[ms.Tensor]: batch_size = len(self._beam_hyps) // self.num_beam_groups if eos_token_id is not None and not isinstance(eos_token_id, ms.Tensor): @@ -421,7 +421,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): Batch Size of `input_ids` for which standard beam search decoding is run in parallel. num_beams (`int`): Number of beams for beam search. - constraints (`List[Constraint]`): + constraints (`list[Constraint]`): A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation output. For more information, the documentation of [`Constraint`] should be read. length_penalty (`float`, *optional*, defaults to 1.0): @@ -449,7 +449,7 @@ def __init__( self, batch_size: int, num_beams: int, - constraints: List[Constraint], + constraints: list[Constraint], length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[Union[bool, str]] = False, num_beam_hyps_to_keep: Optional[int] = 1, @@ -508,10 +508,10 @@ def process( next_indices: ms.Tensor, scores_for_all_vocab: ms.Tensor, pad_token_id: Optional[Union[int, ms.Tensor]] = None, - eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None, + eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None, beam_indices: Optional[ms.Tensor] = None, decoder_prompt_len: Optional[int] = 0, - ) -> Tuple[ms.Tensor]: + ) -> tuple[ms.Tensor]: r""" Args: input_ids (`ms.Tensor` of shape `(batch_size * num_beams, sequence_length)`): @@ -531,7 +531,7 @@ def process( The scores of all tokens in the vocabulary for each of the beam hypotheses. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): + eos_token_id (`Union[int, list[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. beam_indices (`ms.Tensor`, *optional*): Beam indices indicating to which beam hypothesis each token correspond. 
@@ -807,10 +807,10 @@ def finalize( final_beam_indices: ms.Tensor, max_length: int, pad_token_id: Optional[Union[int, ms.Tensor]] = None, - eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None, + eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None, beam_indices: Optional[ms.Tensor] = None, decoder_prompt_len: Optional[int] = 0, - ) -> Tuple[ms.Tensor]: + ) -> tuple[ms.Tensor]: batch_size = len(self._beam_hyps) if eos_token_id is not None and not isinstance(eos_token_id, ms.Tensor): From 375e6ab67b77077ae766750336dc260425d7b8bd Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:40:13 +0800 Subject: [PATCH 18/94] feat(transformers): upgrade candidate_generator to v4.54 --- .../generation/candidate_generator.py | 217 ++++++++++-------- 1 file changed, 118 insertions(+), 99 deletions(-) diff --git a/mindone/transformers/generation/candidate_generator.py b/mindone/transformers/generation/candidate_generator.py index f5c554ca07..3095f24a23 100644 --- a/mindone/transformers/generation/candidate_generator.py +++ b/mindone/transformers/generation/candidate_generator.py @@ -24,10 +24,10 @@ from transformers import is_sklearn_available import mindspore as ms -from mindspore import mint +from mindspore import mint, nn from mindspore import numpy as mnp -from ..cache_utils import DynamicCache +from ..mindspore_utils import prune_linear_layer if is_sklearn_available(): from sklearn.metrics import roc_curve @@ -286,9 +286,7 @@ def _update_past_and_masks(self, input_ids: ms.Tensor, remove_from_pkv: int = 0, has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None if has_past_key_values: new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv - self.assistant_kwargs["past_key_values"] = _crop_past_key_values( - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens - ) + self.assistant_kwargs["past_key_values"] = self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens) self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder ) @@ -603,13 +601,68 @@ def _process_assistant_outputs( return new_target_ids +class _PruneReindexingLMHead(nn.Cell): + """ + A class to prune and reindex the language model head. + + This class prunes the language model head to only include the specified token IDs and reindexes the logits + to map back to the original vocabulary. + + Args: + original_lm_head (nn.Module): The original language model head. + token_ids (list[int]): The list of token IDs to keep. + """ + + def __init__(self, original_lm_head, assistant_overlap_token_ids): + super().__init__() + self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to( + original_lm_head.weight.dtype + ) + + def construct(self, hidden_states): + pruned_logits = self.pruned_lm_head(hidden_states) + return pruned_logits + + +class _MapInputEmbedding(nn.Cell): + def __init__(self, original_embedding: mint.nn.Embedding, assistant_overlap_token_ids): + """ + Wraps an existing embedding layer and remaps token IDs before lookup. + + Args: + original_embedding (mint.nn.Embedding): Pre-trained or existing embedding layer. + assistant_overlap_token_ids (dict): Mapping from original token IDs to new token IDs. 
+ Example: {old_id: new_id} + """ + super().__init__() + self.original_embedding = original_embedding + self.weight = original_embedding.weight + self.assistant_overlap_token_ids = assistant_overlap_token_ids + self.map = False + + def construct(self, input_ids: ms.Tensor) -> ms.Tensor: + """ + Args: + input_ids (ms.Tensor): Tensor of token IDs (batch_size, seq_len). + + Returns: + ms.Tensor: Corresponding input embeddings. + """ + if self.map: + # Get the last item from input_ids + my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0) + else: + self.map = True + my_input_ids = input_ids + + return self.original_embedding(my_input_ids) class AssistantToTargetTranslator: """ Translates token ids and logits between assistant and target model vocabularies. This class is used to handle vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding, as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies" - (https://www.arxiv.org/abs/2502.05202). + (https://huggingface.co/papers/2502.05202). It maintains mappings between the two vocabularies and handles token/logit conversion. Args: @@ -617,8 +670,12 @@ class AssistantToTargetTranslator: The tokenizer used by the target (main) model. assistant_tokenizer (`PreTrainedTokenizerBase`): The tokenizer used by the assistant model. - target_vocab_size (`int`, *optional*): + target_vocab_size (`int`): The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer. + assistant_model (Optional[PreTrainedModel], optional): The assistant model to be used. Defaults to None for backward compatibility. + assistant_prune_lm_head (bool): Whether to prune the assistant model's language model + head to match the target vocabulary. This is only applicable if `assistant_model` is provided. + Defaults to False for backward compatibility. """ FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits. 
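# Illustrative sketch (not part of the patch): how AssistantToTargetTranslator's two
# mappings relate, following the mapping-construction logic visible in the hunk below.
# `toy_target_vocab` and `toy_assistant_vocab` are hypothetical, and -1 stands in for
# the SUPPRESS_TOKEN_ID placeholder used for assistant tokens missing from the target
# vocabulary.
toy_target_vocab = {"hello": 0, "world": 1}
toy_assistant_vocab = {"hello": 0, "there": 1, "world": 2}
assistant_to_target = [-1] * (max(toy_assistant_vocab.values()) + 1)
target_to_assistant = {}
for tok, assistant_id in toy_assistant_vocab.items():
    target_id = toy_target_vocab.get(tok)
    if target_id is not None:
        assistant_to_target[assistant_id] = target_id
        target_to_assistant[target_id] = assistant_id
# assistant_to_target == [0, -1, 1]; target_to_assistant == {0: 0, 1: 2}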
@@ -629,56 +686,49 @@ def __init__( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab() + assistant_model: Optional["PreTrainedModel"] = None, + assistant_prune_lm_head: bool = False, ): - self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer - self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer + self._target_tokenizer: PreTrainedTokenizerBase = target_tokenizer + self._assistant_tokenizer: PreTrainedTokenizerBase = assistant_tokenizer self.target_vocab_size: int = target_vocab_size - ( - self._assistant_to_target_input_ids, - self.target_to_assistant_input_ids, - ) = self._get_assistant_to_target_input_ids() + self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = ( + self._get_assistant_to_target_input_ids() + ) self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: Optional[LogitsProcessorList] = None + self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None if len(self._suppress_input_ids) > 0: - # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab - self.logits_processors = LogitsProcessorList( - [SuppressTokensLogitsProcessor(self._get_suppress_input_ids())] - ) + # the assistant vocab is not a subset of the target vocab + if self.assistant_prune_lm_head: + self.assistant_overlap_token_ids = ms.tensor( + list(self.target_to_assistant_input_ids.values()), + dtype=ms.int64, + ) + original_lm_head = assistant_model.get_output_embeddings() + pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids) + del original_lm_head + assistant_model.set_output_embeddings(pruned_lm_head) + + original_input_embeddings = assistant_model.get_input_embeddings() + map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids) + del original_input_embeddings + assistant_model.set_input_embeddings(map_input_embeddings) + self.map_input_embeddings = map_input_embeddings + else: + self.logits_processors = LogitsProcessorList( + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids())] + ) - def _get_assistant_to_target_input_ids(self): - target_vocab = self._target_tokenizer.get_vocab() - assistant_vocab = self._assistant_tokenizer.get_vocab() - - space_str = " " - target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"] - if len(target_space_ids) > 0: - target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0] - - assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"] - if len(assistant_space_ids) > 0: - assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0] - - if target_space_sign != assistant_space_sign: - # If the assistant tokenizer has a different space sign than the target tokenizer, - # we need to replace the assistant space sign with the target space sign in the assistant_vocab. 
- assistant_vocab = { - ( - tok.replace(assistant_space_sign, target_space_sign, 1) - if tok.startswith(assistant_space_sign) - else tok - ): idx - for tok, idx in assistant_vocab.items() - } - - max_assistant_index = max(assistant_vocab.values()) - assistant_to_target_input_ids = mint.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int) - target_to_assistant_input_ids: dict[int, int] = {} - for tok, assistant_id in assistant_vocab.items(): - target_id = target_vocab.get(tok) - if target_id is not None: - assistant_to_target_input_ids[assistant_id] = target_id - target_to_assistant_input_ids[target_id] = assistant_id - return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids + def unmap_input_ids(self): + """ + Disables the mapping of input ids despite the assistant pruning for the language model head being enabled. + + This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the assistant vocabulary space. By disabling the mapping, it ensures that the input ids are processed correctly without remapping. + + """ + if self.assistant_prune_lm_head: + self.map_input_embeddings.map = False def _get_suppress_input_ids(self) -> list[int]: """ @@ -697,7 +747,12 @@ def get_target_ids(self, assistant_input_ids, target_input_ids, assistant_candid if num_new_tokens == 0: return target_input_ids else: - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] + # Get last `num_new_tokens` candidate IDs + last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:] + if self.assistant_prune_lm_head: + # Map assistant IDs -> target input IDs + last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids] + transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids] return mint.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, assistant_logits: ms.Tensor) -> ms.Tensor: @@ -713,8 +768,11 @@ def get_target_logits(self, assistant_logits: ms.Tensor) -> ms.Tensor: target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] - target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] - + if self.assistant_prune_lm_head: + target_logits[..., target_logits_supported_indices] = assistant_logits + else: + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] return target_logits @@ -732,7 +790,8 @@ def get_translator( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", target_vocab_size: int, - assistant_model_device: str = "cpu", + assistant_model: Optional["PreTrainedModel"] = None, + assistant_prune_lm_head: bool = False, ) -> AssistantToTargetTranslator: assistant_dict = cls._cache.get(target_tokenizer) if assistant_dict is None: @@ -742,7 +801,7 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: mapping = AssistantToTargetTranslator( - target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device + target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model, assistant_prune_lm_head ) assistant_dict[assistant_tokenizer] = mapping @@ -879,7 +938,7 @@ def 
_prepare_assistant_input_ids(self, target_input_ids: ms.Tensor) -> ms.Tensor self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] assistant_input_ids = mint.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(dtype=ms.int64) - + self._atm_translator.unmap_input_ids() return assistant_input_ids, len(assistant_new_ids[0]) @@ -925,7 +984,7 @@ def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.T Return: `ms.Tensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. """ - input_length = input_ids.size(1) + input_length = input_ids.shape[1] # Don't generate more than `max_length - 1` candidates since the target model generates one extra token. if self.max_length == input_length + 1: @@ -940,6 +999,7 @@ def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.T # Convert ngram to a tensor for comparison ngram_tensor = input_ids[0, -ngram_size:] + # Find where the windows match the ngram # Find where the windows match the ngram matches = (windows == ngram_tensor).all(dim=2) @@ -1051,47 +1111,6 @@ def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.T return candidate_ids, candidate_logits -def _crop_past_key_values(model, past_key_values, max_length): - """Crops the past key values up to a certain maximum length.""" - new_past = [] - if model.config.is_encoder_decoder: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :max_length, :], - past_key_values[idx][1][:, :, :max_length, :], - past_key_values[idx][2], - past_key_values[idx][3], - ) - ) - past_key_values = tuple(new_past) - # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model - elif "gptbigcode" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() - ): - if model.config.multi_query: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :max_length, :] - else: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :, :max_length, :] - elif isinstance(past_key_values, DynamicCache): - past_key_values.crop(max_length) - elif past_key_values is not None: - for idx in range(len(past_key_values)): - if past_key_values[idx] != ([], []): - new_past.append( - ( - past_key_values[idx][0][:, :, :max_length, :], - past_key_values[idx][1][:, :, :max_length, :], - ) - ) - else: - new_past.append((past_key_values[idx][0], past_key_values[idx][1])) - past_key_values = tuple(new_past) - return past_key_values - - def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]: """Expands or crops the model's mask for decoding purposes, to the defined length""" From 25033c102275eb65e87729c13809e0883f4d0b9f Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:03:01 +0800 Subject: [PATCH 19/94] feat(transformers): upgrade logits_process/stopping_criteria to v4.54 --- .../transformers/generation/logits_process.py | 396 +++++++++++++++--- .../generation/stopping_criteria.py | 21 +- 2 files changed, 351 insertions(+), 66 deletions(-) diff --git a/mindone/transformers/generation/logits_process.py b/mindone/transformers/generation/logits_process.py index afc98e331e..7bcc5ca7b8 100644 --- 
a/mindone/transformers/generation/logits_process.py +++ b/mindone/transformers/generation/logits_process.py @@ -91,7 +91,7 @@ def __call__( scores (`Union[ms.Tensor, np.ndarray]` of shape `(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): Additional kwargs that are specific to a logits processor. Return: @@ -122,11 +122,11 @@ class MinLengthLogitsProcessor(LogitsProcessor): Args: min_length (`int`): The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`Union[int, List[int], ms.Tensor, np.ndarray]`): + eos_token_id (`Union[int, list[int], ms.Tensor, np.ndarray]`): The id(s) of the *end-of-sequence* token. """ - def __init__(self, min_length: int, eos_token_id: Union[int, List[int], ms.Tensor, np.ndarray], **ignore): + def __init__(self, min_length: int, eos_token_id: Union[int, list[int], ms.Tensor, np.ndarray], **ignore): if not isinstance(min_length, int) or min_length < 0: raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}") @@ -175,12 +175,12 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): input length. min_new_tokens (`int`): The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`Union[int, List[int], ms.Tensor]`): + eos_token_id (`Union[int, list[int], ms.Tensor]`): The id(s) of the *end-of-sequence* token. """ def __init__( - self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], ms.Tensor], **ignore + self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, list[int], ms.Tensor], **ignore ): for arg_name, arg_value in [ ("prompt_length_to_skip", prompt_length_to_skip), @@ -269,9 +269,10 @@ def __call__( class RepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at - most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt. + most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt + by default. - In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around + In the original [paper](https://huggingface.co/papers/papers/1909.05858), the authors suggest the use of a penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly. @@ -280,21 +281,69 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor): penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated tokens. Between 0.0 and 1.0 rewards previously generated tokens. + prompt_ignore_length (`int`, *optional*): + The original input ids sequence length, which if provided, will not be used in the penalty calculation. 
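The interaction between the penalty and `prompt_ignore_length` is easiest to see on a toy example. The sketch below is illustrative only (the token ids, vocabulary size and prompt length are made up); it mirrors the gather/where/scatter logic used by this processor, applied after slicing the prompt off the history.

```python
import mindspore as ms
from mindspore import mint

penalty = 1.2
prompt_ignore_length = 3  # hypothetical prompt length to exclude from the penalty

# Toy history: 3 prompt tokens followed by 2 generated tokens (ids 5 and 7)
input_ids = ms.tensor([[2, 4, 9, 5, 7]])
scores = mint.ones((1, 10), dtype=ms.float32)  # fake logits over a 10-token vocabulary

penalized_ids = input_ids[:, prompt_ignore_length:]             # only the generated tokens
score = mint.gather(scores, 1, penalized_ids)                   # current logits of those tokens
score = mint.where(score < 0, score * penalty, score / penalty) # penalize previously seen tokens
scores = scores.scatter(1, penalized_ids, score)
# Columns 5 and 7 are scaled down; the prompt tokens 2, 4 and 9 keep their original logits.
```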
""" - def __init__(self, penalty: float): + def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + if prompt_ignore_length is not None and ( + not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0 + ): + raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}") + self.penalty = penalty + self.prompt_ignore_length = prompt_ignore_length + self.logits_indices = None + self.cumulative_seqlens_q = None + + def set_continuous_batching_context(self, logits_indices: ms.Tensor, cumulative_seqlens_q: ms.Tensor): + self.logits_indices = logits_indices + self.cumulative_seqlens_q = cumulative_seqlens_q @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor: - score = mint.gather(scores, 1, input_ids) + if self.prompt_ignore_length: + input_ids = input_ids[:, self.prompt_ignore_length:] + + if scores.dim() == 3: + if self.logits_indices is not None and self.cumulative_seqlens_q is not None: + batch_size, seq_len, vocab_size = scores.shape + last_positions = self.logits_indices + last_scores = scores[0, last_positions, :] + + # Prepare token mask + token_mask = mint.zeros_like(last_scores, dtype=ms.bool_) + cu_seq_lens = self.cumulative_seqlens_q + lengths = cu_seq_lens[1:] - cu_seq_lens[:-1] + seq_indices = mint.repeat_interleave(mint.arange(len(lengths)), lengths) + token_mask[seq_indices, input_ids] = True + + # Apply penalty + penalty_scores = mint.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty) + scores[0, last_positions, :] = mint.where(token_mask, penalty_scores, last_scores) + else: + batch_size, seq_len, vocab_size = scores.shape + last_scores = scores[:, -1, :] + token_mask = mint.zeros_like(last_scores, dtype=ms.bool_) + if input_ids.dim() == 1: + unique_tokens = mint.unique(input_ids) + token_mask.scatter_(1, unique_tokens.unsqueeze(0), True) + else: + token_mask.scatter_(1, input_ids, True) + # if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities + penalty_scores = mint.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty) + scores[:, -1, :] = mint.where(token_mask, penalty_scores, last_scores) + return scores + + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(1) + score = mint.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities score = mint.where(score < 0, score * self.penalty, score / self.penalty) - scores_processed = scores.scatter(1, input_ids, score) return scores_processed @@ -477,7 +526,7 @@ def __call__( class MinPLogitsWarper(LogitsProcessor): """ [`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the - probability of the most likely token. As a result, the filter becomes more agressive in the presence of + probability of the most likely token. As a result, the filter becomes more aggressive in the presence of high-probability tokens, which is a sign of a confident output that we shouldn't deviate from. Often used together with [`TemperatureLogitsWarper`]. 
Used as an alternative to [`TopPLogitsWarper`] and @@ -558,7 +607,7 @@ class TypicalLogitsWarper(LogitsProcessor): whose log probability is close to the entropy of the token probability distribution. This means that the most likely tokens may be discarded in the process. - See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. + See [Typical Decoding for Natural Language Generation](https://huggingface.co/papers/2202.00666) for more information. Args: mass (`float`, *optional*, defaults to 0.9): @@ -644,7 +693,7 @@ class EpsilonLogitsWarper(LogitsProcessor): r""" [`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model - Desmoothing](https://arxiv.org/abs/2210.15191) for more information. + Desmoothing](https://huggingface.co/papers/2210.15191) for more information. Args: epsilon (`float`): @@ -714,7 +763,7 @@ class EtaLogitsWarper(LogitsProcessor): the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation - Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample` + Sampling as Language Model Desmoothing](https://huggingface.co/papers/2210.15191) for more information. Note: `do_sample` must be set to `True` for this `LogitsProcessor` to work. @@ -840,7 +889,7 @@ def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): def _calc_banned_ngram_tokens( ngram_size: int, prev_input_ids: ms.Tensor, num_hypos: int, cur_len: int -) -> List[Iterable[int]]: +) -> list[Iterable[int]]: """Copied from fairseq for no_repeat_ngram in beam_search""" if cur_len + 1 < ngram_size: # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet @@ -867,7 +916,7 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor): Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York might lead to undesirable outcomes where the city's name appears only once in the entire text. - [Reference](https://huggingface.co/blog/how-to-generate) + [Reference](https://huggingface.co/papers/blog/how-to-generate) @@ -990,12 +1039,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): At a token-level, biasing a word is different from biasing a word with a space before it. If you want to bias "foo" mid-sentence, you'll likely want to add a prefix space and bias " foo" instead. Check the tokenizer section - of our NLP course to find out why: https://huggingface.co/learn/nlp-course/chapter2/4?fw=pt + of our NLP course to find out why: https://huggingface.co/papers/learn/nlp-course/chapter2/4?fw=pt Args: - sequence_bias (`List[List[Union[List[int], float]]]`): + sequence_bias (`list[list[Union[list[int], float]]]`): List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0], [[64], -7.5]]`). Positive biases increase the odds of the sequence being selected, while negative biases do the opposite. 
If a sequence has a length of 1, its bias @@ -1005,7 +1054,8 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): Examples: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import AutoModelForCausalLM >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") @@ -1044,13 +1094,13 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, sequence_bias: List[List[Union[List[int], float]]]): + def __init__(self, sequence_bias: list[list[Union[list[int], float]]]): self.sequence_bias = sequence_bias self._validate_arguments() self._convert_list_arguments_into_dict() # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size - # is infered in the first usage, which inhibits initializing here) + # is inferred in the first usage, which inhibits initializing here) self.length_1_bias = None self.prepared_bias_variables = False @@ -1106,9 +1156,13 @@ def _prepare_bias_variables(self, scores: ms.Tensor): # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied # with simpler logic. self.length_1_bias = mint.zeros((vocabulary_size,), dtype=ms.float32) + # Extract single-token sequences and their biases + single_token_ids = [] + single_token_biases = [] for sequence_ids, bias in self.sequence_bias.items(): if len(sequence_ids) == 1: - self.length_1_bias[sequence_ids[-1]] = bias + single_token_ids.append(sequence_ids[0]) + single_token_biases.append(bias) self.prepared_bias_variables = True @@ -1166,14 +1220,14 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor): `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more - [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers). + [here](https://huggingface.co/papers/docs/tokenizers/api/pre-tokenizers). Args: - bad_words_ids (`List[List[int]]`): + bad_words_ids (`list[list[int]]`): List of list of token ids that are not allowed to be generated. - eos_token_id (`Union[int, List[int], ms.Tensor]`, *optional*): + eos_token_id (`Union[int, list[int], ms.Tensor]`, *optional*): The id(s) of the *end-of-sequence* token. 
Examples: @@ -1211,7 +1265,7 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor): ``` """ - def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None): + def __init__(self, bad_words_ids: list[list[int]], eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None): self.bad_word_ids = bad_words_ids self._validate_arguments() @@ -1222,8 +1276,9 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[ eos_token_id = [eos_token_id] eos_token_id = ms.tensor(eos_token_id) + eos_token_id_list = eos_token_id.tolist() # convert to python list before bad_words_ids = list( - filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids) + filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id_list), bad_words_ids) ) # Forbidding a sequence is equivalent to setting its bias to -inf @@ -1248,17 +1303,17 @@ def _validate_arguments(self): class PrefixConstrainedLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained - generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information. + generation. See [Autoregressive Entity Retrieval](https://huggingface.co/papers/2010.00904) for more information. Args: - prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`): + prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], list[int]]`): This function constraints the beam search to allowed tokens only at each step. This function takes 2 arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID `batch_id`. """ - def __init__(self, prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], List[int]], num_beams: int): + def __init__(self, prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], list[int]], num_beams: int): self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn self._num_beams = num_beams @@ -1326,7 +1381,8 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): Examples: ```python - >>> from mindone.transformers import AutoTokenizer, AutoModelForSeq2SeqLM + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import AutoModelForSeq2SeqLM >>> # Initialize the model and tokenizer >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") @@ -1491,7 +1547,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): Args: max_length (`int`): The maximum length of the sequence to be generated. - eos_token_id (`Union[int, List[int], ms.Tensor]`): + eos_token_id (`Union[int, list[int], ms.Tensor]`): The id(s) of the *end-of-sequence* token. 
Examples: @@ -1516,7 +1572,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): ``` """ - def __init__(self, max_length: int, eos_token_id: Union[int, List[int], ms.Tensor]): + def __init__(self, max_length: int, eos_token_id: Union[int, list[int], ms.Tensor]): self.max_length = max_length if not isinstance(eos_token_id, ms.Tensor): @@ -1571,7 +1627,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): exponential_decay_length_penalty (`tuple(int, float)`): This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay - eos_token_id (`Union[int, List[int], ms.Tensor]`): + eos_token_id (`Union[int, list[int], ms.Tensor]`): The id(s) of the *end-of-sequence* token. input_ids_seq_length (`int`): The length of the input sequence. @@ -1632,7 +1688,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): def __init__( self, exponential_decay_length_penalty: tuple[int, float], - eos_token_id: Union[int, List[int], ms.Tensor], + eos_token_id: Union[int, list[int], ms.Tensor], input_ids_seq_length: int, ): self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length @@ -1691,7 +1747,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): [`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not generated at the beginning. Originally created for - [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper). + [Whisper](https://huggingface.co/papers/docs/transformers/model_doc/whisper). Examples: @@ -1745,7 +1801,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): r""" This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they are not generated. Originally created for - [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper). + [Whisper](https://huggingface.co/papers/docs/transformers/model_doc/whisper). Examples: @@ -1797,7 +1853,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): potential tokens. - See [the paper](https://arxiv.org/abs/2212.04356) for more information. + See [the paper](https://huggingface.co/papers/2212.04356) for more information. Args: generate_config (`GenerateConfig`): @@ -1814,7 +1870,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): Examples: ``` python - >>> from mindone.transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig + >>> from mindone.transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from transformers import GenerationConfig >>> from datasets import load_dataset >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") @@ -1858,11 +1915,13 @@ def __init__( if _detect_timestamp_from_logprob is not None else getattr(generate_config, "_detect_timestamp_from_logprob", True) ) - - num_forced_ids = ( - len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 - ) - self.begin_index = begin_index or (num_forced_ids + 1) + self.begin_index = begin_index + if begin_index is None: + raise ValueError( + "`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` " + "must be provided to `WhisperTimeStampLogitsProcessor`. 
The previous default value of `begin_index` " + "was `len(generate_config.forced_decoder_ids)`" + ) self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 @@ -1947,6 +2006,10 @@ def set_inputs(self, inputs): self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs} self.inputs["input_features"] = self.inputs.pop("inputs") + # Whisper encoder-decoder does not accept the input_ids as input + if "input_ids" not in inspect.signature(self.model.forward).parameters: + self.inputs.pop("input_ids", None) + @property def no_speech_prob(self): return self._no_speech_prob @@ -1985,12 +2048,12 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. - See [the paper](https://arxiv.org/abs/2306.05284) for more information. + See [the paper](https://huggingface.co/papers/2306.05284) for more information. This logits processor is exclusively compatible with - [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) + [MusicGen](https://huggingface.co/papers/docs/transformers/main/en/model_doc/musicgen) @@ -2049,7 +2112,7 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor): This logits processor is exclusively compatible with - [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation + [Bark](https://huggingface.co/papers/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation for examples. @@ -2093,7 +2156,7 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch. - See [the paper](https://arxiv.org/abs/2306.17806) for more information. + See [the paper](https://huggingface.co/papers/2306.17806) for more information. Args: guidance_scale (`float`): @@ -2116,7 +2179,8 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): Examples: ```python - >>> from mindone.transformers import AutoTokenizer, AutoModelForCausalLM + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import AutoModelForCausalLM >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") @@ -2211,18 +2275,18 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): This logits processor is exclusively compatible with - [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples. + [Bark](https://huggingface.co/papers/docs/transformers/en/model_doc/bark). See the model documentation for examples. Args: - eos_token_id (`Union[int, List[int], ms.Tensor]`): + eos_token_id (`Union[int, list[int], ms.Tensor]`): The id(s) of the *end-of-sequence* token. min_eos_p (`float`, *optional*): Minimum end of speech threshold. 
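The effect of `min_eos_p` can be sketched as follows. This is an illustrative outline rather than the exact implementation, and the token id and logits are invented: once the softmax probability of the EOS token crosses the threshold, every other logit is masked so EOS is guaranteed to be selected.

```python
import mindspore as ms
from mindspore import mint

eos_token_id = 2   # hypothetical id
min_eos_p = 0.2
scores = ms.tensor([[0.1, 0.4, 1.5, -0.3]])  # fake logits over a 4-token vocabulary

probs = mint.nn.functional.softmax(scores, dim=-1)
do_early_stop = probs[:, eos_token_id:eos_token_id + 1] > min_eos_p  # shape (batch, 1)

# Keep only the EOS logit for rows that crossed the threshold
eos_only = mint.full_like(scores, float("-inf"))
eos_only[:, eos_token_id] = scores[:, eos_token_id]
scores_processed = mint.where(do_early_stop, eos_only, scores)
# Rows whose EOS probability exceeds min_eos_p now have every non-EOS logit at -inf.
```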
""" - def __init__(self, eos_token_id: Union[int, List[int], ms.Tensor], min_eos_p: float): + def __init__(self, eos_token_id: Union[int, list[int], ms.Tensor], min_eos_p: float): if not isinstance(eos_token_id, ms.Tensor): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -2262,7 +2326,7 @@ class WatermarkLogitsProcessor(LogitsProcessor): The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details, - See [the paper](https://arxiv.org/abs/2306.04634) for more information. + See [the paper](https://huggingface.co/papers/2306.04634) for more information. Args: vocab_size (`int`): @@ -2457,7 +2521,7 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor): Args: ngram_len (`int`): Ngram length. - keys (`List[int]`): + keys (`list[int]`): A sequence of watermarking keys, one for each depth. sampling_table_size (`int`): Size of the sampling table. @@ -2496,7 +2560,7 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor): def __init__( self, ngram_len: int, - keys: List[int], + keys: list[int], sampling_table_size: int, sampling_table_seed: int, context_history_size: int, @@ -2668,7 +2732,7 @@ def compute_ngram_keys(self, ngrams: ms.Tensor) -> ms.Tensor: ngram keys (batch_size, num_ngrams, depth). """ if len(ngrams.shape) != 3: - raise ValueError("Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but" f" is {ngrams.shape}") + raise ValueError(f"Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but is {ngrams.shape}") if ngrams.shape[2] != self.ngram_len: raise ValueError( "Ngrams should be of shape (batch_size, num_ngrams, ngram_len)," @@ -2856,3 +2920,223 @@ def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> The expected mean g-value for watermarked text. """ return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size)) + + +class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original + `ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall + calculation, e.g. conditioned logits centered, and an additional top k selection + option. + + + + This logits processor is exclusively compatible with + [Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia) + + + + Args: + guidance_scale (float): + The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + Higher guidance scale encourages the model to generate samples that are more closely linked to the input + prompt, usually at the expense of poorer quality. + guidance_top_k (int, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. However, we do not keep + the logits of the combined CFG output, but the conditioned output only. + """ + + def __init__(self, guidance_scale: float, guidance_top_k: Optional[int] = None): + if guidance_scale > 1: + self.guidance_scale = guidance_scale + else: + raise ValueError( + "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale " + f"{guidance_scale}." 
+ ) + + self.guidance_top_k = guidance_top_k + if self.guidance_top_k is not None and self.guidance_top_k < 1: + raise ValueError( + f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}" + ) + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor: + # simple check to make sure we have compatible batch sizes between our + # logits scores (cond + uncond) and input ids (cond only) + if scores.shape[0] != 2 * input_ids.shape[0]: + raise ValueError( + f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to " + f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got " + f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids." + ) + # Base CFG with center on cond_logits + unguided_bsz = scores.shape[0] // 2 + cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0) + scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale + + # Optional CFG top k filtering + if self.guidance_top_k is not None: + # Create top k based on the combined CFG output + _, top_k_indices = mint.topk(scores_processed, k=self.guidance_top_k, dim=-1) + top_k_mask = mint.ones_like(scores_processed, dtype=ms.bool_) + top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False) + # Only return conditioned logits with top k + scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf")) + + return scores_processed + + +class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor): + r"""Specialized processor that ensures certain properties around EOS sampling: + 1. Only channel 0 can generate EOS + 2. If channel 0 has EOS with highest logit, it will be the only candidate + 3. If channel 0 has EOS not with highest logit, it will be suppressed + + 2. and 3. are especially important in contexts where we allow sampling to guarantee the + respective tokens to be (not) sampled. + + + + This logits processor is exclusively compatible with + [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia). + + + + Args: + num_channels (`int`): + Number of audio codebooks. Simplifies access to the first channel on the logits. + eos_token_id (`int`): + The id of *end-of-sequence* token. + """ + + def __init__(self, num_channels: int, eos_token_id: int): + if num_channels < 1: + raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.") + if eos_token_id < 1: + raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.") + + self.num_channels = num_channels + self.eos_id = eos_token_id + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor: + # Reshape for easier channel indexing [B, C, V] + scores = scores.reshape(-1, self.num_channels, scores.shape[-1]) + + # EOS filter + # 1. Condition: Only the first channel can generate the EOS token + # Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...) 
+ # (Assumes them to be greater than audio eos token position) + scores[:, 1:, self.eos_id :] = mint.full_like( + scores[:, 1:, self.eos_id :], + fill_value=-float("inf"), + ) + scores[:, 0, self.eos_id + 1 :] = mint.full_like( + scores[:, 0, self.eos_id + 1 :], + fill_value=-float("inf"), + ) + + # 2+3 Conditions: Force/Suppress EOS if (not) highest logit + # Reshape back to original shape + scores = scores.view(-1, scores.shape[-1]) + + # Sample highest tokens + top_logit_indices = mint.argmax(scores, dim=-1) + + # 2. Force EOS + eos_highest_mask = top_logit_indices == self.eos_id + mask_eos_highest = mint.zeros_like(scores, dtype=ms.bool_) + mask_eos_highest[eos_highest_mask, : self.eos_id] = True + scores = scores.masked_fill(mask_eos_highest, -float("inf")) + + # 3. Suppress EOS + eos_not_highest_mask = top_logit_indices != self.eos_id + mask_eos_unless_highest = mint.zeros_like(scores, dtype=ms.bool_) + mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True + scores = scores.masked_fill(mask_eos_unless_highest, -float("inf")) + + return scores + + +class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor): + r"""Special logits processor to handle the generation of the EOS token in Dia. + This is due to the fact that Dia does not allow the generation of EOS in all + channels except the first channel (C0). + + Hence, based on the delay pattern, an EOS is forced after the respective delays + in the channels. For example, if the delay pattern is [0, 2, 3, 4]: + + s s+1 s+2 s+3 s+4 s+5 ... + | | | | | | + C0: EOS PAD PAD PAD PAD PAD ... + C1: x x EOS PAD PAD PAD ... + C2: x x x EOS PAD PAD ... + C3: x x x x EOS PAD ... + + If the first channel generated EOS at step s, channels Cx are forced to generate + theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are + handled by the `EosTokenCriteria` when an EOS has been detected. + + + + This logits processor is exclusively compatible with + [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia). + + + + Args: + delay_pattern (`List[int]`): + The delays per channel in the audio codebooks. + eos_token_id (`int`): + The id of *end-of-sequence* token. + max_generation_len (`int`): + The max sequence length that can be generated. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors on. 
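The delay bookkeeping can be traced on a toy example. The loop below is only a sketch of the idea (batch size, delay pattern and the step at which channel 0 emits EOS are all made up); it reproduces the `active`/`delay` updates performed in `__call__` without touching any logits.

```python
import mindspore as ms

delay = ms.tensor([[0, 2, 3, 4]], dtype=ms.int32)  # made-up delay pattern, one batch element
active = ms.tensor([False])

# Pretend channel 0 produces EOS at step 1 of this toy loop
eos_detected = [False, True, False, False, False, False]
for step, eos_now in enumerate(eos_detected):
    active = active | ms.tensor([eos_now])
    forced = active[:, None] & (delay == 0)  # channels that must emit EOS at this step
    delay = delay - active[:, None].int()
    print(step, forced.asnumpy().tolist())
# Channel 0 is forced at step 1, channel 1 at step 3, channel 2 at step 4 and channel 3 at
# step 5, i.e. 2, 3 and 4 steps after channel 0, matching the delay pattern above.
```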
+ """ + + def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int): + self.num_channels = len(delay_pattern) + # Update during first iteration + self.active_batches = None + self.delay_pattern = ms.tensor(delay_pattern, dtype=ms.int32)[None, :] + self.eos_token_id = eos_token_id + self.max_generation_len = max_generation_len - max(delay_pattern) - 1 + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor: + # Reshape for easier channel indexing [B, C, V] + scores = scores.reshape(-1, self.num_channels, scores.shape[-1]) + + # Initialize / expand values on first iteration + if self.active_batches is None: + self.delay_pattern = self.delay_pattern.repeat(scores.shape[0], 1) + self.active_batches = mint.zeros(size=(scores.shape[0],), dtype=ms.bool_) + + # Check if eos has been generated in any batch + channel_generated_eos = mint.argmax(scores, dim=-1)[:, 0] == self.eos_token_id + # Check if max len has been reached + reached_max_len = input_ids.shape[1] == self.max_generation_len + + # Update active batches + self.active_batches |= channel_generated_eos + self.active_batches |= reached_max_len + + # Find channels that need to force eos + forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0) + # Use indexing to avoid issues on all `False` by having empty tensors in that case + idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True) + + # Force eos if delay is kicking in + scores[idx_bsz, idx_channel, :] = -float("inf") + scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0 + + # Reshape back to [B * C, V] + scores = scores.reshape(-1, scores.shape[-1]) + + # Update amount of delay left for each channel + self.delay_pattern -= self.active_batches[:, None].int() + + return scores diff --git a/mindone/transformers/generation/stopping_criteria.py b/mindone/transformers/generation/stopping_criteria.py index 7d5b78365d..8489f7e5c4 100644 --- a/mindone/transformers/generation/stopping_criteria.py +++ b/mindone/transformers/generation/stopping_criteria.py @@ -24,14 +24,14 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" Args: - input_ids (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, sequence_length)`): + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - scores (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, config.vocab_size)`): + scores (`ms.Tensor` of shape `(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. @@ -39,8 +39,9 @@ Additional stopping criteria specific kwargs. Return: - `Union[ms.Tensor, numpy.ndarray]`. (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, 1)`), where `True` indicates we stop generation - for a particular row, `True` indicates we should continue. + `ms.Tensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`): + `True` indicates we stop generation for a particular row. + `False` indicates we should continue. 
""" @@ -77,7 +78,7 @@ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = Non def __call__( self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs ) -> Union[ms.Tensor, np.ndarray]: - cur_len = input_ids.shape[-1] + cur_len = input_ids.shape[1] is_done = cur_len >= self.max_length if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: logger.warning_once( @@ -226,7 +227,7 @@ class StopStringCriteria(StoppingCriteria): Args: tokenizer (`PreTrainedTokenizer`): The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) - stop_strings (`Union[str, List[str]]`): + stop_strings (`Union[str, list[str]]`): A list of strings that should end generation. If a string is passed, it will be treated like a list with a single element. @@ -256,7 +257,7 @@ class StopStringCriteria(StoppingCriteria): ``` """ - def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): + def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, list[str]]): if isinstance(stop_strings, str): stop_strings = [stop_strings] self.stop_strings: tuple[str, ...] = tuple(stop_strings) @@ -315,7 +316,7 @@ def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): @staticmethod def _stop_string_get_matching_positions( token_list, token_indices, stop_strings - ) -> tuple[dict[str, dict[str, List[int]]], dict[str, dict[str, List[int]]]]: + ) -> tuple[dict[str, dict[str, list[int]]], dict[str, dict[str, list[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters @@ -470,11 +471,11 @@ class EosTokenCriteria(StoppingCriteria): By default, it uses the `model.generation_config.eos_token_id`. Args: - eos_token_id (`Union[int, List[int], ms.Tensor]`): + eos_token_id (`Union[int, list[int], ms.Tensor]`): The id(s) of the *end-of-sequence* token. 
""" - def __init__(self, eos_token_id: Union[int, List[int], ms.Tensor]): + def __init__(self, eos_token_id: Union[int, list[int], ms.Tensor]): # to list if not isinstance(eos_token_id, ms.Tensor): if isinstance(eos_token_id, int): From 252f4aa704f1220075cd0440577421d2bd30a5ae Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:16:09 +0800 Subject: [PATCH 20/94] pre-commit --- mindone/transformers/__init__.py | 7 +- mindone/transformers/activations.py | 6 +- mindone/transformers/cache_utils.py | 173 +++--------------- .../generation/candidate_generator.py | 16 +- .../transformers/generation/logits_process.py | 8 +- .../generation/stopping_criteria.py | 2 +- mindone/transformers/generation/utils.py | 46 ++--- mindone/transformers/modeling_layers.py | 3 +- mindone/transformers/modeling_rope_utils.py | 12 +- mindone/transformers/modeling_utils.py | 53 +++--- .../transformers/models/aria/modeling_aria.py | 2 +- .../models/cohere2/modeling_cohere2.py | 2 +- .../transformers/models/ernie4_5/__init__.py | 2 +- .../models/ernie4_5/modeling_ernie4_5.py | 33 ++-- .../models/gemma/modeling_gemma.py | 2 +- .../transformers/models/glm/modeling_glm.py | 2 +- .../transformers/models/glm4/modeling_glm4.py | 3 +- .../models/glm4v/modeling_glm4v.py | 2 +- .../models/granite/modeling_granite.py | 2 +- .../models/helium/modeling_helium.py | 2 +- .../models/llama/modeling_llama.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- .../transformers/models/phi3/modeling_phi3.py | 2 +- .../models/qwen2/modeling_qwen2.py | 3 +- .../models/qwen3/modeling_qwen3.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- mindone/transformers/utils/generic.py | 7 +- 27 files changed, 142 insertions(+), 256 deletions(-) diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 70a06c1910..526898df15 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -26,6 +26,7 @@ from packaging import version from .cache_utils import * + # Feature Extractor from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .image_processing_base import ImageProcessingMixin @@ -274,11 +275,7 @@ DebertaV2PreTrainedModel, ) from .models.dpt import DPTForDepthEstimation -from .models.ernie4_5 import ( - Ernie4_5PreTrainedModel, - Ernie4_5Model, - Ernie4_5ForCausalLM -) +from .models.ernie4_5 import Ernie4_5ForCausalLM, Ernie4_5Model, Ernie4_5PreTrainedModel from .models.fuyu import FuyuForCausalLM, FuyuPreTrainedModel from .models.gemma import ( GemmaForCausalLM, diff --git a/mindone/transformers/activations.py b/mindone/transformers/activations.py index 691fb55947..92b1d8c746 100644 --- a/mindone/transformers/activations.py +++ b/mindone/transformers/activations.py @@ -19,7 +19,7 @@ from collections import OrderedDict import mindspore as ms -from mindspore import Tensor, nn, mint +from mindspore import Tensor, mint, nn class PytorchGELUTanh(nn.Cell): @@ -43,7 +43,9 @@ class NewGELUActivation(nn.Cell): def construct(self, input: Tensor) -> Tensor: return ( - 0.5 * input * (1.0 + mint.tanh(mint.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * mint.pow(input, 3.0)))) + 0.5 + * input + * (1.0 + mint.tanh(mint.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * mint.pow(input, 3.0)))) ).to(input.dtype) diff --git a/mindone/transformers/cache_utils.py b/mindone/transformers/cache_utils.py index b95ed93cb3..8f1b02d598 100644 --- a/mindone/transformers/cache_utils.py +++ 
b/mindone/transformers/cache_utils.py @@ -10,14 +10,14 @@ import os from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import mint, ops logger = logging.get_logger(__name__) @@ -149,6 +149,7 @@ def reset(past_key_values): return past_key_values + class CacheLayerMixin(ABC): """Base, abstract class for a single layer's cache.""" @@ -163,16 +164,20 @@ def update( key_states: ms.Tensor, value_states: ms.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[ms.Tensor, ms.Tensor]: ... + ) -> tuple[ms.Tensor, ms.Tensor]: + ... @abstractmethod - def get_seq_length(self, cache_position=None) -> int: ... + def get_seq_length(self, cache_position=None) -> int: + ... @abstractmethod - def get_max_cache_shape(self) -> int: ... + def get_max_cache_shape(self) -> int: + ... @abstractmethod - def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: ... + def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]: + ... def reset(self) -> None: """Resets the cache values while preserving the objects""" @@ -186,6 +191,7 @@ def reorder_cache(self, beam_idx: ms.Tensor) -> tuple[ms.Tensor, ms.Tensor]: if self.values.numel(): self.values = self.values.index_select(0, beam_idx) + class DynamicLayer(CacheLayerMixin): """ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. @@ -716,9 +722,7 @@ def __init__( compute_dtype: ms.Type = ms.float16, ) -> None: """Initialize the quanto quantization processor.""" - super().__init__( - cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype - ) + super().__init__(cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype) raise NotImplementedError @@ -741,9 +745,7 @@ def __init__( compute_dtype: ms.Type = ms.float16, ) -> None: """Initialize the HQQ quantization processor.""" - super().__init__( - cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype - ) + super().__init__(cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype) raise NotImplementedError @@ -1008,7 +1010,7 @@ def value_cache(self) -> KeyValuesWrapper: ) return KeyValuesWrapper(self.layers, "values") - ### Wrappers for layer operations and properties ### + # Wrappers for layer operations and properties ### def get_max_cache_shape(self, layer_idx: int = 0) -> int: """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length.""" @@ -1128,6 +1130,7 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], . cache.update(key_states, value_states, layer_idx) return cache + class StaticCache(Cache): """ Static Cache class to be used with `mindspore.jit(model)`. @@ -1159,136 +1162,6 @@ def __init__(self, *args, **kwargs): super().__init__(layer_classes=StaticLayer, *args, **kwargs) -class SlidingWindowCache(StaticCache): - """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. 
- Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - batch_size (`int`): - The batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - dtype (`ms.Type`, *optional*, defaults to `ms.float32`): - The default `dtype` to use when initializing the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer - >>> from mindone.transformers import AutoModelForCausalLM, SlidingWindowCache - - >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - - >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SlidingWindowCache() - ``` - """ - - is_sliding = True - is_compileable = True - - # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. - def __init__( - self, - config: PretrainedConfig, - batch_size: int = None, - max_cache_len: int = None, - dtype: ms.Type = ms.float32, - max_batch_size: Optional[int] = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - max_cache_len = min(config.sliding_window, max_cache_len) - super().__init__( - config=config, - batch_size=batch_size, - max_cache_len=max_cache_len, - dtype=dtype, - max_batch_size=max_batch_size, - ) - - def update( - self, - key_states: ms.Tensor, - value_states: ms.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[ms.Tensor]: - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) - if cache_position.shape[0] > self.max_cache_len: - k_out = key_states[:, :, -self.max_cache_len :, :] - v_out = value_states[:, :, -self.max_cache_len :, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = ops.ones(self.max_cache_len, dtype=ms.int32).cumsum(0) - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position >= self.max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len - - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) - self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - - return k_out, v_out - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx]) - self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx]) - - class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `mindspore.jit` for models like Mistral that support sliding window attention. @@ -1370,6 +1243,7 @@ def __init__(self, config: PretrainedConfig, *args, **kwargs): layer_classes = [StaticLayer] * config.num_hidden_layers super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + class EncoderDecoderCache(Cache): """ Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and @@ -1459,9 +1333,7 @@ def to_legacy_cache(self) -> tuple[tuple[ms.Tensor]]: return legacy_cache @classmethod - def from_legacy_cache( - cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], ...] 
- ) -> "EncoderDecoderCache": + def from_legacy_cache(cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], ...]) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" cache = cls( self_attention_cache=DynamicCache(), @@ -1504,8 +1376,7 @@ def reorder_cache(self, beam_idx: ms.Tensor): def check_dynamic_cache(self, method: str): if not ( - isinstance(self.self_attention_cache, DynamicCache) - and isinstance(self.cross_attention_cache, DynamicCache) + isinstance(self.self_attention_cache, DynamicCache) and isinstance(self.cross_attention_cache, DynamicCache) ): raise ValueError( f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " @@ -1554,6 +1425,7 @@ def get_max_cache_shape(self) -> int: def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]: return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) + class MambaCache: def __init__(self): raise NotImplementedError @@ -1563,6 +1435,7 @@ class OffloadedStaticCache(StaticCache): def __init__(self): raise NotImplementedError + def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: """ Parse processor arguments from kwargs based on the processor class init signature. @@ -1653,6 +1526,7 @@ def parse_layer_args_from_model_config( } return {k: v for k, v in layer_args.items() if v is not None} + LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { "full_attention": StaticLayer, "sliding_attention": SlidingWindowLayer, @@ -1664,6 +1538,7 @@ def parse_layer_args_from_model_config( "hqq_quantized": HQQQuantizedCacheProcessor, } + class CacheConfig: """ Base class for cache configs @@ -1756,4 +1631,4 @@ def update(self, **kwargs): # Remove all the attributes that were updated, without modifying the input dict unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs \ No newline at end of file + return unused_kwargs diff --git a/mindone/transformers/generation/candidate_generator.py b/mindone/transformers/generation/candidate_generator.py index 3095f24a23..d8c42b0702 100644 --- a/mindone/transformers/generation/candidate_generator.py +++ b/mindone/transformers/generation/candidate_generator.py @@ -286,7 +286,9 @@ def _update_past_and_masks(self, input_ids: ms.Tensor, remove_from_pkv: int = 0, has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None if has_past_key_values: new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv - self.assistant_kwargs["past_key_values"] = self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens) + self.assistant_kwargs["past_key_values"] = self.assistant_kwargs["past_key_values"].crop( + new_cache_size - num_added_tokens + ) self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder ) @@ -601,6 +603,7 @@ def _process_assistant_outputs( return new_target_ids + class _PruneReindexingLMHead(nn.Cell): """ A class to prune and reindex the language model head. @@ -657,6 +660,7 @@ def construct(self, input_ids: ms.Tensor) -> ms.Tensor: return self.original_embedding(my_input_ids) + class AssistantToTargetTranslator: """ Translates token ids and logits between assistant and target model vocabularies. 
This class is used to handle @@ -692,9 +696,10 @@ def __init__( self._target_tokenizer: PreTrainedTokenizerBase = target_tokenizer self._assistant_tokenizer: PreTrainedTokenizerBase = assistant_tokenizer self.target_vocab_size: int = target_vocab_size - self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = ( - self._get_assistant_to_target_input_ids() - ) + ( + self._assistant_to_target_input_ids, + self.target_to_assistant_input_ids, + ) = self._get_assistant_to_target_input_ids() self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: Optional[LogitsProcessorList] = None self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None @@ -724,7 +729,8 @@ def unmap_input_ids(self): """ Disables the mapping of input ids despite the assistant pruning for the language model head being enabled. - This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the assistant vocabulary space. By disabling the mapping, it ensures that the input ids are processed correctly without remapping. + This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the assistant vocabulary space. + By disabling the mapping, it ensures that the input ids are processed correctly without remapping. """ if self.assistant_prune_lm_head: diff --git a/mindone/transformers/generation/logits_process.py b/mindone/transformers/generation/logits_process.py index 7bcc5ca7b8..88faf247f3 100644 --- a/mindone/transformers/generation/logits_process.py +++ b/mindone/transformers/generation/logits_process.py @@ -18,7 +18,7 @@ import inspect import math -from typing import Callable, Iterable, List, Optional, Union +from typing import Callable, Iterable, Optional, Union import numpy as np from transformers.utils import add_start_docstrings @@ -289,9 +289,7 @@ def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") - if prompt_ignore_length is not None and ( - not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0 - ): + if prompt_ignore_length is not None and (not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0): raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}") self.penalty = penalty @@ -306,7 +304,7 @@ def set_continuous_batching_context(self, logits_indices: ms.Tensor, cumulative_ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor: if self.prompt_ignore_length: - input_ids = input_ids[:, self.prompt_ignore_length:] + input_ids = input_ids[:, self.prompt_ignore_length :] if scores.dim() == 3: if self.logits_indices is not None and self.cumulative_seqlens_q is not None: diff --git a/mindone/transformers/generation/stopping_criteria.py b/mindone/transformers/generation/stopping_criteria.py index 8489f7e5c4..2548d26c70 100644 --- a/mindone/transformers/generation/stopping_criteria.py +++ b/mindone/transformers/generation/stopping_criteria.py @@ -5,7 +5,7 @@ from abc import ABC from collections import OrderedDict from copy import deepcopy -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np from transformers import PreTrainedTokenizerBase diff --git a/mindone/transformers/generation/utils.py 
b/mindone/transformers/generation/utils.py index d17a899226..95382ba550 100644 --- a/mindone/transformers/generation/utils.py +++ b/mindone/transformers/generation/utils.py @@ -16,27 +16,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import copy import inspect +import os import time import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import numpy as np from huggingface_hub import file_exists from packaging import version from transformers import logging -from transformers.generation.configuration_utils import CompileConfig, GenerationConfig, GenerationMode -from transformers.tokenization_utils import ExtensionsTrie -from transformers.utils.generic import ModelOutput from transformers.dynamic_module_utils import ( check_python_requirements, get_cached_module_file, get_class_in_module, resolve_trust_remote_code, ) +from transformers.generation.configuration_utils import CompileConfig, GenerationConfig, GenerationMode +from transformers.tokenization_utils import ExtensionsTrie +from transformers.utils.generic import ModelOutput import mindspore as ms import mindspore.numpy as mnp @@ -361,6 +361,7 @@ class GenerationMixin: To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). """ + def load_custom_generate( self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, @@ -452,10 +453,7 @@ def _cache_dependant_input_preparation( # fixme there is no implementation for torch dynamo exporting if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] - elif ( - inputs_embeds is not None # Exception 1 - or (cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): + elif inputs_embeds is not None or (cache_position[-1] >= input_ids.shape[1]): # Exception 1 # Exception 3 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] @@ -1397,9 +1395,7 @@ def _get_logits_processor( # Watermarking should be after all logits processing is finished (see #34630) if generation_config.watermarking_config is not None: processors.append( - generation_config.watermarking_config.construct_processor( - self.config.get_text_config().vocab_size - ) + generation_config.watermarking_config.construct_processor(self.config.get_text_config().vocab_size) ) # `LogitNormalization` should always be the last logit processor, when present @@ -1857,7 +1853,7 @@ def _prepare_generation_config( # - otherwise: legacy behavior, let's just make sure we have the tokens defined model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version) if use_model_defaults is True or ( - use_model_defaults is None and model_base_version >= version.parse("4.50.0") + use_model_defaults is None and model_base_version >= version.parse("4.50.0") ): modified_values = {} global_default_generation_config = GenerationConfig() @@ -1869,8 +1865,8 @@ def _prepare_generation_config( global_default_value = getattr(global_default_generation_config, key, None) custom_gen_config_value = getattr(generation_config, key, None) if ( - 
custom_gen_config_value == global_default_value - and model_gen_config_value != global_default_value + custom_gen_config_value == global_default_value + and model_gen_config_value != global_default_value ): modified_values[key] = model_gen_config_value setattr(generation_config, key, model_gen_config_value) @@ -1986,7 +1982,7 @@ def _get_cache( } if cache_implementation in ["static", "hybrid", "offloaded_static"]: cache_kwargs.update({"tp_size": self.tp_size}) - + self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() @@ -1995,7 +1991,7 @@ def _get_cache( else: self._cache.reset() return self._cache - + def _supports_default_dynamic_cache(self) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. @@ -2074,10 +2070,10 @@ def _prepare_cache_for_generation( Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is instantiated, writes it to `model_kwargs`, under the name expected by the model. """ - - is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"]) + # fixme is_hybrid_cache is never used + # is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"]) cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" - + requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None ) @@ -2327,7 +2323,7 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge return False # Base logic - valid_hardware = ms.get_context("mode")==0 or bool( + valid_hardware = ms.get_context("mode") == 0 or bool( generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices ) using_compilable_cache = ( @@ -2867,7 +2863,7 @@ def _sample( os.environ["TOKENIZERS_PARALLELISM"] = "0" # If we use FA2 and a static cache, we cannot compile with fullgraph if self.config._attn_implementation == "flash_attention_2" and getattr( - model_kwargs.get("past_key_values"), "is_compileable", False + model_kwargs.get("past_key_values"), "is_compileable", False ): if generation_config.compile_config is None: generation_config.compile_config = CompileConfig(fullgraph=False) @@ -3603,9 +3599,7 @@ def _prefill_chunking(self, input_ids: ms.Tensor, generation_config: GenerationC # Prepare inputs if attention_mask is not None: model_kwargs["attention_mask"] = attention_mask[:, :current_length] - model_kwargs["cache_position"] = mint.arange( - past_length, current_length, dtype=ms.int64 - ) + model_kwargs["cache_position"] = mint.arange(past_length, current_length, dtype=ms.int64) model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0) model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs) @@ -3618,4 +3612,4 @@ def _prefill_chunking(self, input_ids: ms.Tensor, generation_config: GenerationC model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 _ = model_kwargs.pop("position_ids", None) - return model_kwargs \ No newline at end of file + return model_kwargs diff --git a/mindone/transformers/modeling_layers.py b/mindone/transformers/modeling_layers.py index 1f23264b2b..bff31b14ac 100644 --- a/mindone/transformers/modeling_layers.py +++ b/mindone/transformers/modeling_layers.py @@ -17,6 +17,8 @@ from abc import ABC from typing import Optional +from 
transformers.utils import auto_docstring, can_return_tuple + import mindspore as ms import mindspore.nn as nn from mindspore import mint @@ -31,7 +33,6 @@ from .models.auto import AutoModel from .processing_utils import Unpack from .utils import TransformersKwargs, logging -from transformers.utils import auto_docstring, can_return_tuple logger = logging.get_logger(__name__) diff --git a/mindone/transformers/modeling_rope_utils.py b/mindone/transformers/modeling_rope_utils.py index 42587d6b05..cb9bc525be 100644 --- a/mindone/transformers/modeling_rope_utils.py +++ b/mindone/transformers/modeling_rope_utils.py @@ -177,9 +177,7 @@ def _compute_dynamic_ntk_parameters( return inv_freq, attention_factor -def _compute_yarn_parameters( - config: PretrainedConfig, seq_len: Optional[int] = None -) -> tuple["ms.Tensor", float]: +def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Please refer to the [original paper](https://huggingface.co/papers/2309.00071) @@ -264,9 +262,7 @@ def linear_ramp_factor(min, max, dim): return inv_freq, attention_factor -def _compute_longrope_parameters( - config: PretrainedConfig, seq_len: Optional[int] = None -) -> tuple["ms.Tensor", float]: +def _compute_longrope_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies with LongRoPE scaling. Please refer to the [original implementation](https://github.com/microsoft/LongRoPE) @@ -317,9 +313,7 @@ def _compute_longrope_parameters( return inv_freq, attention_factor -def _compute_llama3_parameters( - config: PretrainedConfig, seq_len: Optional[int] = None -) -> tuple["ms.Tensor", float]: +def _compute_llama3_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> tuple["ms.Tensor", float]: """ Computes the inverse frequencies for llama 3.1. diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 9786b79cdb..5956f19270 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -25,7 +25,7 @@ import warnings from contextlib import contextmanager, nullcontext from dataclasses import dataclass -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union, get_type_hints +from typing import Any, Callable, MutableMapping, Optional, Union from transformers.configuration_utils import PretrainedConfig from transformers.dynamic_module_utils import custom_object_save @@ -79,8 +79,8 @@ prune_linear_layer, ) from .modeling_attn_mask_utils import dtype_to_min -from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available from .utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder +from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available if is_safetensors_available(): from safetensors import safe_open @@ -396,6 +396,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name + def _get_mindspore_dtype( cls, mindspore_dtype: Optional[Union[str, ms.Type, dict]], @@ -442,6 +443,7 @@ def _get_mindspore_dtype( # TODO: We cannot set default mindspore dtype! 
return config, mindspore_dtype + def _find_missing_and_unexpected_keys( cls, model: "PreTrainedModel", @@ -728,6 +730,7 @@ def floating_point_ops(self, input_dict: dict[str, Union[ms.Tensor, Any]], exclu return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) + class EmbeddingAccessMixin: """ Base utilities to regroup getters and setters for embeddings. @@ -817,7 +820,10 @@ def set_output_embeddings(self, new_embeddings): if getattr(self, "lm_head"): self.lm_head = new_embeddings -class PreTrainedModel(nn.Cell, EmbeddingAccessMixin, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): + +class PreTrainedModel( + nn.Cell, EmbeddingAccessMixin, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin +): r""" Base class for all models. @@ -1181,8 +1187,6 @@ def _from_config(cls, config, **kwargs): mindspore_dtype = str(mindspore_dtype) mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype] - use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) - config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs) @@ -1232,9 +1236,11 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" ) - if not is_flash_attn_2_available(): - preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" - install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + # fixme variable is assigned but never used + # if not is_flash_attn_2_available(): + # preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" + # install_message = "Please refer to the documentation of + # https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." if mindspore_dtype is None: logger.warning_once( @@ -1243,8 +1249,9 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: elif mindspore_dtype is not None and mindspore_dtype not in [ms.float16, ms.bfloat16]: logger.warning_once( "Flash Attention 2 only supports ms.float16 and ms.bfloat16 dtypes, but" - f" the current dype in {self.__class__.__name__} is {mindspore_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," - ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`' + f" the current dype in {self.__class__.__name__} is {mindspore_dtype}. You should run training or inference using Automatic Mixed-Precision," + ' or load the model with the `torch_dtype` argument. ' + 'Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", mindspore_dtype=ms.float16)`' ) # With the early check, the parameters are not yet initalized correctly @@ -1272,7 +1279,8 @@ def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool: raise ValueError( f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." 
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" - ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ' this error is a bug, please open an issue in Transformers GitHub repository and ' + 'load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' ) if not is_sdpa_available(): raise ImportError("MindSpore SDPA requirements in Transformers are not met.") @@ -1301,7 +1309,6 @@ def _check_and_adjust_attn_implementation( """ applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation): - # Extract repo_id and kernel_name from the string if ":" in applicable_attn_implementation: repo_id, kernel_name = attn_implementation.split(":") @@ -1365,9 +1372,7 @@ def _can_set_attn_implementation(cls) -> bool: with open(class_file, "r") as f: code = f.read() # heuristic -> if we find those patterns, the model uses the correct interface - return ( - "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code - ) + return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code def set_attn_implementation(self, attn_implementation: Union[str, dict]): """ @@ -2343,14 +2348,13 @@ def from_pretrained( subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) - use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_name = kwargs.pop("adapter_name", "default") key_mapping = kwargs.pop("key_mapping", None) # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model if key_mapping is None and any( - allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS + allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS ): key_mapping = cls._checkpoint_conversion_mapping @@ -2735,7 +2739,14 @@ def from_pretrained( # Find the correct dtype based on current state config, mindspore_dtype = _get_mindspore_dtype( - cls, mindspore_dtype, resolved_archive_file, config, sharded_metadata, state_dict, weights_only, is_sharded + cls, + mindspore_dtype, + resolved_archive_file, + config, + sharded_metadata, + state_dict, + weights_only, + is_sharded, ) # Check if `_keep_in_fp32_modules` is not None @@ -3166,9 +3177,7 @@ def get_compiled_call(self) -> Callable: # Only reset it if not present or different from previous config if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT return self.__call__ - if ( - not hasattr(self, "_compiled_call") - ): + if not hasattr(self, "_compiled_call"): self._compiled_call = ms.jit(self.__call__) return self._compiled_call @@ -3721,4 +3730,4 @@ def valid_keys(self) -> list[str]: ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() # for BC -MSPreTrainedModel = PreTrainedModel \ No newline at end of file +MSPreTrainedModel = PreTrainedModel diff --git 
a/mindone/transformers/models/aria/modeling_aria.py b/mindone/transformers/models/aria/modeling_aria.py index 1f0229f8ea..351df9a229 100644 --- a/mindone/transformers/models/aria/modeling_aria.py +++ b/mindone/transformers/models/aria/modeling_aria.py @@ -31,7 +31,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs import mindspore as ms import mindspore.mint as mint @@ -50,6 +49,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, MSPreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs # from ..auto import AutoModelForCausalLM, AutoModel from ..idefics3 import Idefics3VisionTransformer diff --git a/mindone/transformers/models/cohere2/modeling_cohere2.py b/mindone/transformers/models/cohere2/modeling_cohere2.py index 8c7486ccde..6ad17b5261 100644 --- a/mindone/transformers/models/cohere2/modeling_cohere2.py +++ b/mindone/transformers/models/cohere2/modeling_cohere2.py @@ -31,7 +31,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs from transformers.utils.deprecation import deprecate_kwarg import mindspore as ms @@ -49,6 +48,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/ernie4_5/__init__.py b/mindone/transformers/models/ernie4_5/__init__.py index ba97e39a99..b57635823a 100644 --- a/mindone/transformers/models/ernie4_5/__init__.py +++ b/mindone/transformers/models/ernie4_5/__init__.py @@ -1 +1 @@ -from .modeling_ernie4_5 import Ernie4_5PreTrainedModel, Ernie4_5Model, Ernie4_5ForCausalLM \ No newline at end of file +from .modeling_ernie4_5 import Ernie4_5ForCausalLM, Ernie4_5Model, Ernie4_5PreTrainedModel diff --git a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py index 0430438995..9e2f21fce6 100644 --- a/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/mindone/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -20,8 +20,11 @@ from typing import Callable, Optional, Union +from transformers.models.ernie4_5.configuration_ernie4_5 import Ernie4_5Config +from transformers.utils import auto_docstring, can_return_tuple + import mindspore as ms -from mindspore import mint, nn, Parameter +from mindspore import Parameter, mint, nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache @@ -33,8 +36,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs -from transformers.utils import auto_docstring, can_return_tuple -from transformers.models.ernie4_5.configuration_ernie4_5 import Ernie4_5Config class Ernie4_5RotaryEmbedding(nn.Cell): @@ -60,7 +61,7 @@ def __init__(self, config: Ernie4_5Config): def construct(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1)) position_ids_expanded = position_ids[:, None, :].float() - + # fixme there is not implementation for torch.autocast freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).swapaxes(1, 2) emb = mint.cat((freqs, freqs), dim=-1) @@ -183,10 +184,18 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int): self.attention_dropout = 0.0 self.is_causal = True - self.q_proj = 
mint.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) - self.k_proj = mint.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) - self.v_proj = mint.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) - self.o_proj = mint.nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias + ) def construct( self, @@ -299,7 +308,7 @@ def construct( @auto_docstring class Ernie4_5PreTrainedModel(PreTrainedModel): - config : Ernie4_5Config + config: Ernie4_5Config base_model_prefix = "model" supports_gradient_checkpointing = False _no_split_modules = ["Ernie4_5DecoderLayer"] @@ -336,7 +345,7 @@ def __init__(self, config: Ernie4_5Config): # Initialize weights and apply final processing self.post_init() - + @auto_docstring def construct( self, @@ -360,9 +369,7 @@ def construct( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: ms.Tensor = mint.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] - ) + cache_position: ms.Tensor = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/mindone/transformers/models/gemma/modeling_gemma.py b/mindone/transformers/models/gemma/modeling_gemma.py index 7c600a494e..12b62e8f20 100644 --- a/mindone/transformers/models/gemma/modeling_gemma.py +++ b/mindone/transformers/models/gemma/modeling_gemma.py @@ -25,7 +25,6 @@ from typing import Callable, List, Optional, Tuple, Union from transformers.models.gemma.configuration_gemma import GemmaConfig -from ...utils import LossKwargs, logging import mindspore as ms from mindspore import mint, nn, ops @@ -45,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs, logging logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/glm/modeling_glm.py b/mindone/transformers/models/glm/modeling_glm.py index a34009917d..1203c0687e 100644 --- a/mindone/transformers/models/glm/modeling_glm.py +++ b/mindone/transformers/models/glm/modeling_glm.py @@ -27,7 +27,6 @@ from transformers import GlmConfig from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops @@ -47,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs class GlmRMSNorm(nn.Cell): diff --git a/mindone/transformers/models/glm4/modeling_glm4.py b/mindone/transformers/models/glm4/modeling_glm4.py index 6e8dcf85b4..6237837040 100644 --- 
a/mindone/transformers/models/glm4/modeling_glm4.py +++ b/mindone/transformers/models/glm4/modeling_glm4.py @@ -25,7 +25,6 @@ from typing import Callable, Optional, Tuple, Union from transformers.models.glm4.configuration_glm4 import Glm4Config -from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops @@ -44,7 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import logging +from ...utils import LossKwargs, logging logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/glm4v/modeling_glm4v.py b/mindone/transformers/models/glm4v/modeling_glm4v.py index 16dad90618..358499ba93 100644 --- a/mindone/transformers/models/glm4v/modeling_glm4v.py +++ b/mindone/transformers/models/glm4v/modeling_glm4v.py @@ -27,7 +27,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from transformers.utils import logging -from ...utils import LossKwargs import mindspore as ms import mindspore.mint.nn.functional as F @@ -45,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/granite/modeling_granite.py b/mindone/transformers/models/granite/modeling_granite.py index 2eee7a918d..6139de87ea 100644 --- a/mindone/transformers/models/granite/modeling_granite.py +++ b/mindone/transformers/models/granite/modeling_granite.py @@ -31,7 +31,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops @@ -46,6 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "GraniteConfig" diff --git a/mindone/transformers/models/helium/modeling_helium.py b/mindone/transformers/models/helium/modeling_helium.py index 46ce52d98e..2d28a8fd3c 100644 --- a/mindone/transformers/models/helium/modeling_helium.py +++ b/mindone/transformers/models/helium/modeling_helium.py @@ -9,7 +9,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs from transformers.utils.deprecation import deprecate_kwarg import mindspore as ms @@ -32,6 +31,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/llama/modeling_llama.py b/mindone/transformers/models/llama/modeling_llama.py index 1e9fe75258..51511cabb3 100644 --- a/mindone/transformers/models/llama/modeling_llama.py +++ b/mindone/transformers/models/llama/modeling_llama.py @@ -25,7 +25,6 @@ import numpy as np from transformers import LlamaConfig from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging -from ...utils import LossKwargs import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops @@ -42,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/mixtral/modeling_mixtral.py b/mindone/transformers/models/mixtral/modeling_mixtral.py index dfab7eec42..12560299d6 100644 --- a/mindone/transformers/models/mixtral/modeling_mixtral.py +++ b/mindone/transformers/models/mixtral/modeling_mixtral.py @@ -37,7 +37,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops @@ -60,6 +59,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/phi3/modeling_phi3.py b/mindone/transformers/models/phi3/modeling_phi3.py index 2a0e46cf22..93f17a831e 100644 --- a/mindone/transformers/models/phi3/modeling_phi3.py +++ b/mindone/transformers/models/phi3/modeling_phi3.py @@ -28,7 +28,6 @@ from transformers.models.phi3.configuration_phi3 import Phi3Config from transformers.utils import logging -from ...utils import LossKwargs import mindspore as ms from mindspore import mint, nn, ops @@ -48,6 +47,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, MSPreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py index 24f9a2cda6..b5db18b264 100644 --- a/mindone/transformers/models/qwen2/modeling_qwen2.py +++ b/mindone/transformers/models/qwen2/modeling_qwen2.py @@ -15,7 +15,6 @@ from typing import Callable, List, Optional, Tuple, Union from transformers import Qwen2Config, logging -from ...utils import LossKwargs import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops @@ -41,6 +40,8 @@ from mindone.transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, MSPreTrainedModel from mindone.transformers.processing_utils import Unpack +from ...utils import LossKwargs + logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/qwen3/modeling_qwen3.py b/mindone/transformers/models/qwen3/modeling_qwen3.py index 590bed7ec9..de6e45db4e 100644 --- a/mindone/transformers/models/qwen3/modeling_qwen3.py +++ b/mindone/transformers/models/qwen3/modeling_qwen3.py @@ -35,7 +35,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs import mindspore as ms from mindspore import Tensor, mint, nn @@ -59,6 +58,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/starcoder2/modeling_starcoder2.py b/mindone/transformers/models/starcoder2/modeling_starcoder2.py index f9a753f7b1..e977fe6935 100644 --- a/mindone/transformers/models/starcoder2/modeling_starcoder2.py +++ b/mindone/transformers/models/starcoder2/modeling_starcoder2.py @@ -37,7 +37,6 @@ logging, replace_return_docstrings, ) -from ...utils import LossKwargs from transformers.utils.deprecation import deprecate_kwarg import mindspore as ms @@ -58,6 +57,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from 
...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...utils import LossKwargs logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index 670454e373..aa4cf69c21 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -30,12 +30,14 @@ import numpy as np -from mindspore import Tensor +import mindspore as ms +from mindspore import nn, Tensor from .import_utils import is_mindspore_available _CAN_RECORD_REGISTRY = {} + class cached_property(property): """ Descriptor that mimics @property but caches output in member variable. @@ -446,6 +448,7 @@ def mindspore_float(x): return x.to(ms.float32) if isinstance(x, ms.Tensor) else int(x) + class TransformersKwargs(TypedDict, total=False): """ Keyword arguments to be passed to the loss function @@ -478,7 +481,7 @@ class OutputRecorder: layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn". """ - target_class: "type[ms.nn.Cell]" + target_class: "type[nn.Cell]" index: Optional[int] = 0 layer_name: Optional[str] = None From 02834b0f6006326ee65f9422cc4e11dd8914c6d1 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:20:06 +0800 Subject: [PATCH 21/94] pre-commit --- mindone/transformers/cache_utils.py | 2 +- mindone/transformers/modeling_utils.py | 10 ++++++---- mindone/transformers/utils/generic.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mindone/transformers/cache_utils.py b/mindone/transformers/cache_utils.py index 8f1b02d598..f720efd99b 100644 --- a/mindone/transformers/cache_utils.py +++ b/mindone/transformers/cache_utils.py @@ -1194,7 +1194,7 @@ class SlidingWindowCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation SlidingWindowCache() diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 5956f19270..ffecab0e53 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -1250,7 +1250,7 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: logger.warning_once( "Flash Attention 2 only supports ms.float16 and ms.bfloat16 dtypes, but" f" the current dype in {self.__class__.__name__} is {mindspore_dtype}. You should run training or inference using Automatic Mixed-Precision," - ' or load the model with the `torch_dtype` argument. ' + " or load the model with the `torch_dtype` argument. 
" 'Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", mindspore_dtype=ms.float16)`' ) @@ -1258,7 +1258,8 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: if not is_init_check: if getattr(self, "use_bettertransformer", False): raise ValueError( - "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers " + "by doing model.reverse_bettertransformer()" ) # If no error raise by this point, we can return `True` @@ -1279,8 +1280,9 @@ def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool: raise ValueError( f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" - ' this error is a bug, please open an issue in Transformers GitHub repository and ' - 'load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + " this error is a bug, please open an issue in Transformers GitHub repository and " + 'load your model with the argument `attn_implementation="eager"` meanwhile. ' + 'Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' ) if not is_sdpa_available(): raise ImportError("MindSpore SDPA requirements in Transformers are not met.") diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index aa4cf69c21..f0e2af3973 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -31,7 +31,7 @@ import numpy as np import mindspore as ms -from mindspore import nn, Tensor +from mindspore import Tensor, nn from .import_utils import is_mindspore_available From 65e82561fa10758c6afe1b6ae122a1f2000fa02a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:27:52 +0800 Subject: [PATCH 22/94] update backbone_utils --- mindone/transformers/utils/backbone_utils.py | 36 +++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/mindone/transformers/utils/backbone_utils.py b/mindone/transformers/utils/backbone_utils.py index fa182079fe..6ac303da34 100644 --- a/mindone/transformers/utils/backbone_utils.py +++ b/mindone/transformers/utils/backbone_utils.py @@ -19,9 +19,13 @@ import enum import inspect -from typing import Iterable, List, Optional, Tuple, Union +from collections.abc import Iterable +from typing import TYPE_CHECKING, Optional, Union +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class BackboneType(enum.Enum): TIMM = "timm" TRANSFORMERS = "transformers" @@ -73,9 +77,9 @@ def verify_out_features_out_indices( def _align_output_features_output_indices( - out_features: Optional[List[str]], - out_indices: Optional[Union[List[int], Tuple[int]]], - stage_names: List[str], + out_features: Optional[list[str]], + out_indices: Optional[Union[list[int], tuple[int]]], + stage_names: list[str], ): """ Finds the corresponding `out_features` and `out_indices` for the given `stage_names`. 
@@ -89,9 +93,9 @@ def _align_output_features_output_indices( - `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned. Args: - out_features (`List[str]`): The names of the features for the backbone to output. - out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output. - stage_names (`List[str]`): The names of the stages of the backbone. + out_features (`list[str]`): The names of the features for the backbone to output. + out_indices (`list[int]` or `tuple[int]`): The indices of the features for the backbone to output. + stage_names (`list[str]`): The names of the stages of the backbone. """ if out_indices is None and out_features is None: out_indices = [len(stage_names) - 1] @@ -104,10 +108,10 @@ def _align_output_features_output_indices( def get_aligned_output_features_output_indices( - out_features: Optional[List[str]], - out_indices: Optional[Union[List[int], Tuple[int]]], - stage_names: List[str], -) -> Tuple[List[str], List[int]]: + out_features: Optional[list[str]], + out_indices: Optional[Union[list[int], tuple[int]]], + stage_names: list[str], +) -> tuple[list[str], list[int]]: """ Get the `out_features` and `out_indices` so that they are aligned. @@ -120,9 +124,9 @@ def get_aligned_output_features_output_indices( - `out_indices` and `out_features` set: they are verified to be aligned. Args: - out_features (`List[str]`): The names of the features for the backbone to output. - out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output. - stage_names (`List[str]`): The names of the stages of the backbone. + out_features (`list[str]`): The names of the features for the backbone to output. + out_indices (`list[int]` or `tuple[int]`): The indices of the features for the backbone to output. + stage_names (`list[str]`): The names of the stages of the backbone. """ out_indices = list(out_indices) if out_indices is not None else None # First verify that the out_features and out_indices are valid @@ -175,7 +179,7 @@ def out_features(self): return self._out_features @out_features.setter - def out_features(self, out_features: List[str]): + def out_features(self, out_features: list[str]): """ Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. """ @@ -188,7 +192,7 @@ def out_indices(self): return self._out_indices @out_indices.setter - def out_indices(self, out_indices: Union[Tuple[int], List[int]]): + def out_indices(self, out_indices: Union[tuple[int], list[int]]): """ Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices. 
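For illustration, the alignment rules described above boil down to the following behaviour. This is a sketch, not part of the patch; it assumes the helper keeps the upstream `transformers` semantics and is importable from `mindone.transformers.utils.backbone_utils` once this series is applied.

    from mindone.transformers.utils.backbone_utils import get_aligned_output_features_output_indices

    stage_names = ["stem", "stage1", "stage2", "stage3"]

    # Neither argument set: default to the last stage.
    get_aligned_output_features_output_indices(None, None, stage_names)
    # -> (["stage3"], [3])

    # Only out_indices set: out_features is derived from it (and vice versa).
    get_aligned_output_features_output_indices(None, [1, 3], stage_names)
    # -> (["stage1", "stage3"], [1, 3])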
""" From 913cd3c28e2591ef2b6e3516049d1ed95c83e849 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:46:34 +0800 Subject: [PATCH 23/94] update generic --- mindone/transformers/utils/generic.py | 413 +++++++++++++++++++++++++- 1 file changed, 401 insertions(+), 12 deletions(-) diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index 670454e373..d130fd9ec1 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -19,23 +19,33 @@ """ import inspect +import json +import os import tempfile import warnings -from collections import UserDict -from collections.abc import MutableMapping +from collections import OrderedDict, UserDict, defaultdict +from collections.abc import MutableMapping from contextlib import ExitStack, contextmanager +from dataclasses import dataclass, fields, is_dataclass from enum import Enum from functools import wraps -from typing import Callable, ContextManager, List, Optional, TypedDict +from typing import Any, Callable, ContextManager, Optional, TypedDict import numpy as np -from mindspore import Tensor +import logging from .import_utils import is_mindspore_available +if is_mindspore_available(): + import mindspore # noqa: F401 + _CAN_RECORD_REGISTRY = {} + +logger = logging.get_logger(__name__) + + class cached_property(property): """ Descriptor that mimics @property but caches output in member variable. @@ -169,10 +179,17 @@ def to_py_obj(obj): "ms": lambda obj: obj.tolist(), "np": lambda obj: obj.tolist(), } - - if isinstance(obj, (dict, UserDict)): + if isinstance(obj, (int, float)): + return obj + elif isinstance(obj, (dict, UserDict)): return {k: to_py_obj(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): + try: + arr = np.array(obj) + if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating): + return arr.tolist() + except Exception: + pass return [to_py_obj(o) for o in obj] # This gives us a smart order to test the frameworks with the corresponding tests. @@ -212,6 +229,143 @@ def to_numpy(obj): return obj +class ModelOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __init_subclass__(cls) -> None: + """No need to register subclasses as pytree nodes, mindspore does not support pytree. + """ + pass + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Subclasses of ModelOutput must use the @dataclass decorator + # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__ + # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed + # Just need to check that the current class is not ModelOutput + is_modeloutput_subclass = self.__class__ != ModelOutput + + if is_modeloutput_subclass and not is_dataclass(self): + raise TypeError( + f"{self.__module__}.{self.__class__.__name__} is not a dataclass." + " This is a subclass of ModelOutput and so must use the @dataclass decorator." + ) + + def __post_init__(self): + """Check the ModelOutput dataclass. + + Only occurs if @dataclass decorator has been used. 
+ """ + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + if not all(field.default is None for field in class_fields[1:]): + raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for idx, element in enumerate(iterator): + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + if idx == 0: + # If we do not have an iterator of key/values, set it as attribute + self[class_fields[0].name] = first_field + else: + # If we have a mixed iterator, raise an error + raise ValueError( + f"Cannot set key/value for {element}. It needs to be a tuple (key, value)." + ) + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + + def to_tuple(self) -> tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) + + + class ExplicitEnum(str, Enum): """ Enum with more explicit error message for missing values. @@ -251,7 +405,7 @@ class ContextManagers: in the `fastcore` library. 
""" - def __init__(self, context_managers: List[ContextManager]): + def __init__(self, context_managers: list[ContextManager]): self.context_managers = context_managers self.stack = ExitStack() @@ -381,7 +535,7 @@ def tensor_size(array): else: raise ValueError(f"Type not supported for tensor_size: {type(array)}.") - +#TODO: remove this function in v4.54.1 def add_model_info_to_auto_map(auto_map, repo_id): """ Adds the information of the repo_id to a given auto map. @@ -394,7 +548,7 @@ def add_model_info_to_auto_map(auto_map, repo_id): return auto_map - +#TODO: remove this function in v4.54.1 def add_model_info_to_custom_pipelines(custom_pipeline, repo_id): """ Adds the information of the repo_id to a given custom pipeline. @@ -561,6 +715,241 @@ def wrapper(*args, **kwargs): return decorator +class TransformersKwargs(TypedDict, total=False): + """ + Keyword arguments to be passed to the loss function + + Attributes: + num_items_in_batch (`Optional[mindspore.Tensor]`, *optional*): + Number of items in the batch. It is recommended to pass it when + you are doing gradient accumulation. + output_hidden_states (`Optional[bool]`, *optional*): + Most of the models support outputing all hidden states computed during the forward pass. + output_attentions (`Optional[bool]`, *optional*): + Turn this on to return the intermediary attention scores. + output_router_logits (`Optional[bool]`, *optional*): + For MoE models, this allows returning the router logits to compute the loss. + cumulative_seqlens_q (`mindspore.Tensor`, *optional*) + Gets cumulative sequence length for query state. + cumulative_seqlens_k (`mindspore.Tensor`, *optional*) + Gets cumulative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. + """ + + num_items_in_batch: Optional["mindspore.Tensor"] + output_hidden_states: Optional[bool] + output_attentions: Optional[bool] + output_router_logits: Optional[bool] + cumulative_seqlens_q: Optional["mindspore.Tensor"] + cumulative_seqlens_k: Optional["mindspore.Tensor"] + max_length_q: Optional[int] + max_length_k: Optional[int] + +def is_timm_config_dict(config_dict: dict[str, Any]) -> bool: + """Checks whether a config dict is a timm config dict.""" + return "pretrained_cfg" in config_dict + + +def is_timm_local_checkpoint(pretrained_model_path: str) -> bool: + """ + Checks whether a checkpoint is a timm model checkpoint. + """ + if pretrained_model_path is None: + return False + + # in case it's Path, not str + pretrained_model_path = str(pretrained_model_path) + + is_file = os.path.isfile(pretrained_model_path) + is_dir = os.path.isdir(pretrained_model_path) + + # pretrained_model_path is a file + if is_file and pretrained_model_path.endswith(".json"): + with open(pretrained_model_path) as f: + config_dict = json.load(f) + return is_timm_config_dict(config_dict) + + # pretrained_model_path is a directory with a config.json + if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")): + with open(os.path.join(pretrained_model_path, "config.json")) as f: + config_dict = json.load(f) + return is_timm_config_dict(config_dict) + + return False + + +def set_attribute_for_modules(module: "mindspore.nn.Cell", key: str, value: Any): + """ + Set a value to a module and all submodules. 
+ """ + setattr(module, key, value) + for submodule in module.children(): + set_attribute_for_modules(submodule, key, value) + + +def del_attribute_from_modules(module: "mindspore.nn.Cell", key: str): + """ + Delete a value from a module and all submodules. + """ + # because we might remove it previously in case it's a shared module, e.g. activation function + if hasattr(module, key): + delattr(module, key) + + for submodule in module.children(): + del_attribute_from_modules(submodule, key) + + +def can_return_tuple(func): + """ + Decorator to wrap model method, to call output.to_tuple() if return_dict=False passed as a kwarg or + use_return_dict=False is set in the config. + + Note: + output.to_tuple() convert output to tuple skipping all `None` values. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + return_dict = self.config.return_dict if hasattr(self, "config") else True + return_dict_passed = kwargs.pop("return_dict", return_dict) + if return_dict_passed is not None: + return_dict = return_dict_passed + output = func(self, *args, **kwargs) + if not return_dict and not isinstance(output, tuple): + output = output.to_tuple() + return output + + return wrapper + + + +@dataclass +class OutputRecorder: + """ + Configuration for recording outputs from a model via hooks. + + Attributes: + target_class (Type): The class (e.g., nn.Module) to which the hook will be attached. + index (Optional[int]): If the output is a tuple/list, optionally record only at a specific index. + layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn". + """ + + target_class: "type[mindspore.nn.Cell]" + index: Optional[int] = 0 + layer_name: Optional[str] = None + + +def check_model_inputs(func): + """ + Decorator to intercept specific layer outputs without using hooks. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + use_cache = kwargs.get("use_cache", None) + if use_cache is None: + use_cache = getattr(self.config, "use_cache", False) + + return_dict = kwargs.pop("return_dict", None) + if return_dict is None: + return_dict = getattr(self.config, "return_dict", True) + + if getattr(self, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + kwargs["use_cache"] = use_cache + + all_args = kwargs.copy() + if "kwargs" in all_args: + for k, v in all_args["kwargs"].items(): + all_args[k] = v + + capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) + recordable_keys = { + f"output_{k}": all_args.get( + f"output_{k}", + getattr( + self.config, + f"output_{k}", + all_args.get("output_attentions", getattr(self.config, "output_attentions", False)), + ), + ) + for k in capture_flags + } + collected_outputs = defaultdict(tuple) + monkey_patched_layers = [] + + def make_capture_wrapper(module, orig_forward, key, index): + @wraps(orig_forward) + def wrapped_forward(*args, **kwargs): + if key == "hidden_states" and len(collected_outputs[key]) == 0: + collected_outputs[key] += (args[0],) + output = orig_forward(*args, **kwargs) + if not isinstance(output, tuple): + collected_outputs[key] += (output,) + elif output[index] is not None: + collected_outputs[key] += (output[index],) + return output + + return wrapped_forward + + if any(recordable_keys.values()): + capture_tasks = [] + for key, layer_specs in capture_flags.items(): + if not recordable_keys.get(f"output_{key}", False): + continue + if not isinstance(layer_specs, list): + layer_specs = [layer_specs] + for specs in layer_specs: + if not isinstance(specs, OutputRecorder): + index = 0 if "hidden_states" in key else 1 + specs = OutputRecorder(target_class=specs, index=index) + capture_tasks.append((key, specs)) + + for name, module in self.named_modules(): + for key, specs in capture_tasks: + if isinstance(module, specs.target_class): + if specs.layer_name is not None and specs.layer_name not in name: + continue + # Monkey patch forward + original_forward = module.forward + module.forward = make_capture_wrapper(module, original_forward, key, specs.index) + monkey_patched_layers.append((module, original_forward)) + + outputs = func(self, *args, **kwargs) + # Restore original forward methods + for module, original_forward in monkey_patched_layers: + module.forward = original_forward + + # Inject collected outputs into model output + for key in collected_outputs: + if key == "hidden_states": + collected_outputs[key] = collected_outputs[key][:-1] + if hasattr(outputs, "vision_hidden_states"): + collected_outputs[key] += (outputs.vision_hidden_states,) + elif hasattr(outputs, "last_hidden_state"): + collected_outputs[key] += (outputs.last_hidden_state,) + + outputs[key] = collected_outputs[key] + elif key == "attentions": + if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2: + outputs[key] = collected_outputs[key][0::2] + outputs["cross_" + key] = collected_outputs[key][1::2] + else: + outputs[key] = collected_outputs[key] + else: + outputs[key] = collected_outputs[key] + if return_dict is False: + outputs = outputs.to_tuple() + return outputs + + return wrapper class GeneralInterface(MutableMapping): """ @@ -599,10 +988,10 @@ def __len__(self): def register(cls, key: str, value: Callable): cls._global_mapping.update({key: value}) - def valid_keys(self) -> List[str]: + def valid_keys(self) -> list[str]: return list(self.keys()) - +# TODO: remove this class in v4.54.1 class LossKwargs(TypedDict, total=False): """ Keyword arguments to be passed to the loss function @@ -613,4 +1002,4 @@ class LossKwargs(TypedDict, total=False): you are doing gradient accumulation. 
""" - num_items_in_batch: Optional[Tensor] + num_items_in_batch: Optional[mindspore.Tensor] From a0c9dc98d9ff0c9a3082ff185e2e50138369ed2f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 17:04:15 +0800 Subject: [PATCH 24/94] remove add_model_info_to_auto_map & update feature_extraction_utils.py --- .../transformers/feature_extraction_utils.py | 84 ++++++------------- mindone/transformers/utils/generic.py | 31 ++----- 2 files changed, 35 insertions(+), 80 deletions(-) diff --git a/mindone/transformers/feature_extraction_utils.py b/mindone/transformers/feature_extraction_utils.py index f93e8918e9..eff1cec771 100644 --- a/mindone/transformers/feature_extraction_utils.py +++ b/mindone/transformers/feature_extraction_utils.py @@ -23,7 +23,7 @@ import os import warnings from collections import UserDict -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union import numpy as np from transformers.dynamic_module_utils import custom_object_save @@ -40,15 +40,7 @@ logging, ) -from .utils import ( - TensorType, - add_model_info_to_auto_map, - add_model_info_to_custom_pipelines, - is_mindspore_available, - is_mindspore_tensor, - is_numpy_array, - requires_backends, -) +from .utils import TensorType, is_mindspore_available, is_mindspore_tensor, is_numpy_array, requires_backends if TYPE_CHECKING: if is_mindspore_available(): @@ -59,6 +51,9 @@ PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821 +# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin +SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin") + class BatchFeature(UserDict): r""" @@ -75,7 +70,7 @@ class BatchFeature(UserDict): initialization. """ - def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): super().__init__(data) self.convert_to_tensors(tensor_type=tensor_type) @@ -102,18 +97,6 @@ def __setstate__(self, state): if "data" in state: self.data = state["data"] - # Copied from transformers.tokenization_utils_base.BatchEncoding.keys - def keys(self): - return self.data.keys() - - # Copied from transformers.tokenization_utils_base.BatchEncoding.values - def values(self): - return self.data.values() - - # Copied from transformers.tokenization_utils_base.BatchEncoding.items - def items(self): - return self.data.items() - def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None): if tensor_type is None: return None, None @@ -191,7 +174,7 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non def to(self, *args, **kwargs) -> "BatchFeature": """ Args: - args (`Tuple`): + args (`tuple`): Will be passed to the `to(...)` function of the tensors. kwargs (`Dict`, *optional*): Will be passed to the `to(...)` function of the tensors. 
@@ -243,7 +226,7 @@ def _set_processor_class(self, processor_class: str): @classmethod def from_pretrained( - cls, + cls: type[SpecificFeatureExtractorType], pretrained_model_name_or_path: Union[str, os.PathLike], cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, @@ -251,7 +234,7 @@ def from_pretrained( token: Optional[Union[str, bool]] = None, revision: str = "main", **kwargs, - ): + ) -> SpecificFeatureExtractorType: r""" Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a derived class of [`SequenceFeatureExtractor`]. @@ -276,12 +259,12 @@ def from_pretrained( resume_download: Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers. - proxies (`Dict[str, str]`, *optional*): + proxies (`dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + the token generated when running `hf auth login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any @@ -296,10 +279,10 @@ def from_pretrained( return_unused_kwargs (`bool`, *optional*, defaults to `False`): If `False`, then this function returns just the final feature extractor object. If `True`, then this - functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary + functions returns a `tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is controlled by the `return_unused_kwargs` keyword parameter. @@ -365,7 +348,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" use_auth_token = kwargs.pop("use_auth_token", None) @@ -417,7 +400,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: @classmethod def get_feature_extractor_dict( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ) -> tuple[dict[str, Any], dict[str, Any]]: """ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`. @@ -427,7 +410,7 @@ def get_feature_extractor_dict( The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. Returns: - `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object. + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object. """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) @@ -488,13 +471,13 @@ def get_feature_extractor_dict( user_agent=user_agent, revision=revision, ) - except EnvironmentError: + except OSError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # the original exception. raise except Exception: # For any other exception, we throw a generic error. - raise EnvironmentError( + raise OSError( f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" @@ -503,12 +486,12 @@ def get_feature_extractor_dict( try: # Load feature_extractor dict - with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader: + with open(resolved_feature_extractor_file, encoding="utf-8") as reader: text = reader.read() feature_extractor_dict = json.loads(text) except json.JSONDecodeError: - raise EnvironmentError( + raise OSError( f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file." ) @@ -519,30 +502,20 @@ def get_feature_extractor_dict( f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" ) - if not is_local: - if "auto_map" in feature_extractor_dict: - feature_extractor_dict["auto_map"] = add_model_info_to_auto_map( - feature_extractor_dict["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in feature_extractor_dict: - feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( - feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path - ) - return feature_extractor_dict, kwargs @classmethod - def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: + def from_dict(cls, feature_extractor_dict: dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: """ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of parameters. Args: - feature_extractor_dict (`Dict[str, Any]`): + feature_extractor_dict (`dict[str, Any]`): Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be retrieved from a pretrained checkpoint by leveraging the [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method. 
- kwargs (`Dict[str, Any]`): + kwargs (`dict[str, Any]`): Additional parameters from which to initialize the feature extractor object. Returns: @@ -568,10 +541,10 @@ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrain else: return feature_extractor - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ output = copy.deepcopy(self.__dict__) output["feature_extractor_type"] = self.__class__.__name__ @@ -595,7 +568,7 @@ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeature A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor object instantiated from that JSON file. """ - with open(json_file, "r", encoding="utf-8") as reader: + with open(json_file, encoding="utf-8") as reader: text = reader.read() feature_extractor_dict = json.loads(text) return cls(**feature_extractor_dict) @@ -641,11 +614,6 @@ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"): Register this class with a given auto class. This should only be used for custom feature extractors as the ones in the library are already mapped with `AutoFeatureExtractor`. - - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`): diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index d130fd9ec1..ed16b3095c 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -20,11 +20,12 @@ import inspect import json +import logging import os import tempfile import warnings from collections import OrderedDict, UserDict, defaultdict -from collections.abc import MutableMapping +from collections.abc import MutableMapping from contextlib import ExitStack, contextmanager from dataclasses import dataclass, fields, is_dataclass from enum import Enum @@ -33,8 +34,6 @@ import numpy as np -import logging - from .import_utils import is_mindspore_available if is_mindspore_available(): @@ -244,11 +243,9 @@ class ModelOutput(OrderedDict): """ def __init_subclass__(cls) -> None: - """No need to register subclasses as pytree nodes, mindspore does not support pytree. - """ + """No need to register subclasses as pytree nodes, mindspore does not support pytree.""" pass - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -365,7 +362,6 @@ def to_tuple(self) -> tuple[Any]: return tuple(self[k] for k in self.keys()) - class ExplicitEnum(str, Enum): """ Enum with more explicit error message for missing values. @@ -535,20 +531,8 @@ def tensor_size(array): else: raise ValueError(f"Type not supported for tensor_size: {type(array)}.") -#TODO: remove this function in v4.54.1 -def add_model_info_to_auto_map(auto_map, repo_id): - """ - Adds the information of the repo_id to a given auto map. 
- """ - for key, value in auto_map.items(): - if isinstance(value, (tuple, list)): - auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value] - elif value is not None and "--" not in value: - auto_map[key] = f"{repo_id}--{value}" - - return auto_map -#TODO: remove this function in v4.54.1 +# TODO: remove this function in v4.54.1 def add_model_info_to_custom_pipelines(custom_pipeline, repo_id): """ Adds the information of the repo_id to a given custom pipeline. @@ -715,6 +699,7 @@ def wrapper(*args, **kwargs): return decorator + class TransformersKwargs(TypedDict, total=False): """ Keyword arguments to be passed to the loss function @@ -748,6 +733,7 @@ class TransformersKwargs(TypedDict, total=False): max_length_q: Optional[int] max_length_k: Optional[int] + def is_timm_config_dict(config_dict: dict[str, Any]) -> bool: """Checks whether a config dict is a timm config dict.""" return "pretrained_cfg" in config_dict @@ -825,7 +811,6 @@ def wrapper(self, *args, **kwargs): return wrapper - @dataclass class OutputRecorder: """ @@ -870,7 +855,7 @@ def wrapper(self, *args, **kwargs): for k, v in all_args["kwargs"].items(): all_args[k] = v - capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) + capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {}) recordable_keys = { f"output_{k}": all_args.get( f"output_{k}", @@ -951,6 +936,7 @@ def wrapped_forward(*args, **kwargs): return wrapper + class GeneralInterface(MutableMapping): """ Dict-like object keeping track of a class-wide mapping, as well as a local one. Allows to have library-wide @@ -991,6 +977,7 @@ def register(cls, key: str, value: Callable): def valid_keys(self) -> list[str]: return list(self.keys()) + # TODO: remove this class in v4.54.1 class LossKwargs(TypedDict, total=False): """ From 2a517d8ddfe72a0cf1264d67c1a3603dacc52b27 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 17:10:58 +0800 Subject: [PATCH 25/94] remove add_model_info_to_auto_map & update image_processing_base.py --- mindone/transformers/image_processing_base.py | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/mindone/transformers/image_processing_base.py b/mindone/transformers/image_processing_base.py index 18638c8a77..d4dfd78236 100644 --- a/mindone/transformers/image_processing_base.py +++ b/mindone/transformers/image_processing_base.py @@ -40,7 +40,7 @@ ) from .feature_extraction_utils import BatchFeature as BaseBatchFeature -from .utils import add_model_info_to_auto_map, add_model_info_to_custom_pipelines, is_vision_available +from .utils import is_vision_available if is_vision_available(): from PIL import Image @@ -136,7 +136,7 @@ def from_pretrained( 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + the token generated when running `hf auth login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any @@ -352,13 +352,13 @@ def get_image_processor_dict( revision=revision, subfolder=subfolder, ) - except EnvironmentError: + except OSError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # the original exception. raise except Exception: # For any other exception, we throw a generic error. - raise EnvironmentError( + raise OSError( f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" @@ -367,12 +367,12 @@ def get_image_processor_dict( try: # Load image_processor dict - with open(resolved_image_processor_file, "r", encoding="utf-8") as reader: + with open(resolved_image_processor_file, encoding="utf-8") as reader: text = reader.read() image_processor_dict = json.loads(text) except json.JSONDecodeError: - raise EnvironmentError( + raise OSError( f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file." ) @@ -382,14 +382,7 @@ def get_image_processor_dict( logger.info( f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" ) - if "auto_map" in image_processor_dict: - image_processor_dict["auto_map"] = add_model_info_to_auto_map( - image_processor_dict["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in image_processor_dict: - image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( - image_processor_dict["custom_pipelines"], pretrained_model_name_or_path - ) + return image_processor_dict, kwargs @@ -464,7 +457,7 @@ def from_json_file(cls, json_file: Union[str, os.PathLike]): A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object instantiated from that JSON file. """ - with open(json_file, "r", encoding="utf-8") as reader: + with open(json_file, encoding="utf-8") as reader: text = reader.read() image_processor_dict = json.loads(text) return cls(**image_processor_dict) @@ -510,11 +503,6 @@ def register_for_auto_class(cls, auto_class="AutoImageProcessor"): Register this class with a given auto class. This should only be used for custom image processors as the ones in the library are already mapped with `AutoImageProcessor `. - - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`): From 8d297225548bba4b4282a2723c7f45b14cfa852a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:13:48 +0800 Subject: [PATCH 26/94] remove add_model_info_to_auto_map & update processing_utils.py --- mindone/transformers/processing_utils.py | 925 ++++++++++++++--------- 1 file changed, 553 insertions(+), 372 deletions(-) diff --git a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py index bb340c9086..170a83c902 100644 --- a/mindone/transformers/processing_utils.py +++ b/mindone/transformers/processing_utils.py @@ -17,7 +17,7 @@ """ Processing saving/loading class for common processors. 
""" - +import bisect import copy import importlib import inspect @@ -26,23 +26,20 @@ import sys import typing import warnings +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union +from typing import Any, Optional, TypedDict, TypeVar, Union import numpy as np import typing_extensions +from huggingface_hub.errors import EntryNotFoundError from transformers.dynamic_module_utils import custom_object_save -from .image_utils import ( - ChannelDimension, - ImageInput, - VideoInput, - is_valid_image, - is_vision_available, - load_image, - load_video, -) - +from .audio_utils import load_audio +from .feature_extraction_utils import BatchFeature +from .image_utils import ChannelDimension, is_vision_available, load_image +from transformers.utils.chat_template_utils import render_jinja_template +from .video_utils import VideoMetadata, load_video if is_vision_available(): from .image_utils import PILImageResampling @@ -57,6 +54,10 @@ # fixme from transformers.utils import ( + AUDIO_TOKENIZER_NAME, + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, PROCESSOR_NAME, PushToHubMixin, cached_file, @@ -64,13 +65,18 @@ download_url, is_offline_mode, is_remote_url, + list_repo_templates, logging, ) -from .utils import CHAT_TEMPLATE_NAME, TensorType, add_model_info_to_auto_map, add_model_info_to_custom_pipelines +from .utils import TensorType +from transformers.utils.deprecation import deprecate_kwarg logger = logging.get_logger(__name__) +# type hinting: specifying the type of processor class that inherits from ProcessorMixin +SpecificProcessorType = TypeVar("SpecificProcessorType", bound="ProcessorMixin") + # Dynamically import the Transformers module to grab the attribute classes of the processor form their names. # transformers_module = direct_transformers_import(Path(__file__).parent) transformers_module = transformers @@ -80,6 +86,7 @@ "AutoTokenizer": "PreTrainedTokenizerBase", "AutoFeatureExtractor": "FeatureExtractionMixin", "AutoImageProcessor": "ImageProcessingMixin", + "AutoVideoProcessor": "BaseVideoProcessor", } if sys.version_info >= (3, 11): @@ -124,11 +131,13 @@ class TextKwargs(TypedDict, total=False): Whether or not to print more information and warnings. padding_side (`str`, *optional*): The side on which padding will be applied. + return_mm_token_type_ids (`bool`, *optional*): + Whether to return multimodal token type ids indicating mm placeholder token positions. """ - text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] - text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] - text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] + text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] + text_pair_target: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] add_special_tokens: Optional[bool] padding: Union[bool, str, PaddingStrategy] truncation: Union[bool, str, TruncationStrategy] @@ -144,6 +153,7 @@ class TextKwargs(TypedDict, total=False): return_length: Optional[bool] verbose: Optional[bool] padding_side: Optional[str] + return_mm_token_type_ids: Optional[bool] class ImagesKwargs(TypedDict, total=False): @@ -154,11 +164,11 @@ class methods and docstrings. 
Attributes: do_resize (`bool`, *optional*): Whether to resize the image. - size (`Dict[str, int]`, *optional*): + size (`dict[str, int]`, *optional*): Resize the shorter side of the input to `size["shortest_edge"]`. size_divisor (`int`, *optional*): The size by which to make sure both the height and width can be divided. - crop_size (`Dict[str, int]`, *optional*): + crop_size (`dict[str, int]`, *optional*): Desired output size when applying center-cropping. resample (`PILImageResampling`, *optional*): Resampling filter to use if resizing the image. @@ -168,13 +178,13 @@ class methods and docstrings. Scale factor to use if rescaling the image. do_normalize (`bool`, *optional*): Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*): + image_mean (`float` or `list[float]`, *optional*): Mean to use if normalizing the image. - image_std (`float` or `List[float]`, *optional*): + image_std (`float` or `list[float]`, *optional*): Standard deviation to use if normalizing the image. do_pad (`bool`, *optional*): Whether to pad the image to the `(max_height, max_width)` of the images in the batch. - pad_size (`Dict[str, int]`, *optional*): + pad_size (`dict[str, int]`, *optional*): The size `{"height": int, "width" int}` to pad the images to. do_center_crop (`bool`, *optional*): Whether to center crop the image. @@ -185,17 +195,17 @@ class methods and docstrings. """ do_resize: Optional[bool] - size: Optional[Dict[str, int]] + size: Optional[dict[str, int]] size_divisor: Optional[int] - crop_size: Optional[Dict[str, int]] + crop_size: Optional[dict[str, int]] resample: Optional[Union["PILImageResampling", int]] do_rescale: Optional[bool] rescale_factor: Optional[float] do_normalize: Optional[bool] - image_mean: Optional[Union[float, List[float]]] - image_std: Optional[Union[float, List[float]]] + image_mean: Optional[Union[float, list[float]]] + image_std: Optional[Union[float, list[float]]] do_pad: Optional[bool] - pad_size: Optional[Dict[str, int]] + pad_size: Optional[dict[str, int]] do_center_crop: Optional[bool] data_format: Optional[ChannelDimension] input_data_format: Optional[Union[str, ChannelDimension]] @@ -206,47 +216,69 @@ class VideosKwargs(TypedDict, total=False): Keyword arguments for video processing. Attributes: + do_convert_rgb (`bool`): + Whether to convert the video to RGB fromat. do_resize (`bool`): - Whether to resize the image. - size (`Dict[str, int]`, *optional*): + Whether to resize the video. + size (`dict[str, int]`, *optional*): Resize the shorter side of the input to `size["shortest_edge"]`. + default_to_square (`bool`, *optional*, defaults to `self.default_to_square`): + Whether to default to a square when resizing, if size is an int. size_divisor (`int`, *optional*): The size by which to make sure both the height and width can be divided. resample (`PILImageResampling`, *optional*): - Resampling filter to use if resizing the image. + Resampling filter to use if resizing the video. do_rescale (`bool`, *optional*): - Whether to rescale the image by the specified scale `rescale_factor`. + Whether to rescale the video by the specified scale `rescale_factor`. rescale_factor (`int` or `float`, *optional*): - Scale factor to use if rescaling the image. + Scale factor to use if rescaling the video. do_normalize (`bool`, *optional*): - Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*): - Mean to use if normalizing the image. 
- image_std (`float` or `List[float]`, *optional*): - Standard deviation to use if normalizing the image. + Whether to normalize the video. + image_mean (`float` or `list[float]`, *optional*): + Mean to use if normalizing the video. + image_std (`float` or `list[float]`, *optional*): + Standard deviation to use if normalizing the video. do_pad (`bool`, *optional*): - Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + Whether to pad the video to the `(max_height, max_width)` of the videos in the batch. do_center_crop (`bool`, *optional*): - Whether to center crop the image. + Whether to center crop the video. + do_sample_frames (`bool`, *optional*): + Whether to sample frames from the video before processing or to process the whole video. + video_metadata (`VideoMetadata`, *optional*): + Metadata of the video containing information about total duration, fps and total number of frames. + num_frames (`int`, *optional*): + Maximum number of frames to sample when `do_sample_frames=True`. + fps (`int` or `float`, *optional*): + Target frames to sample per second when `do_sample_frames=True`. + crop_size (`dict[str, int]`, *optional*): + Desired output size when applying center-cropping. data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the output image. + The channel dimension format for the output video. input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the input image. + The channel dimension format for the input video. """ + do_convert_rgb: Optional[bool] do_resize: Optional[bool] - size: Optional[Dict[str, int]] + size: Optional[dict[str, int]] size_divisor: Optional[int] + default_to_square: Optional[bool] resample: Optional["PILImageResampling"] do_rescale: Optional[bool] rescale_factor: Optional[float] do_normalize: Optional[bool] - image_mean: Optional[Union[float, List[float]]] - image_std: Optional[Union[float, List[float]]] + image_mean: Optional[Union[float, list[float]]] + image_std: Optional[Union[float, list[float]]] do_pad: Optional[bool] do_center_crop: Optional[bool] + crop_size: Optional[dict[str, int]] data_format: Optional[ChannelDimension] input_data_format: Optional[Union[str, ChannelDimension]] + device: Optional[str] + do_sample_frames: Optional[bool] + video_metadata: Optional[Union[VideoMetadata, dict]] + fps: Optional[Union[int, float]] + num_frames: Optional[int] class AudioKwargs(TypedDict, total=False): @@ -256,7 +288,7 @@ class AudioKwargs(TypedDict, total=False): Attributes: sampling_rate (`int`, *optional*): The sampling rate at which the `raw_speech` input was sampled. - raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not stereo, i.e. single float per timestep. 
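These per-modality `TypedDict`s are meant to be combined by a model-specific `ProcessingKwargs` subclass, which can also pin its own defaults. A rough sketch of that pattern follows; the class names and default values are invented for illustration and are not part of this diff.

```python
from typing import Optional

from mindone.transformers.processing_utils import ImagesKwargs, ProcessingKwargs


class MyImagesKwargs(ImagesKwargs, total=False):
    # hypothetical model-specific image kwarg
    crop_to_patches: Optional[bool]


class MyProcessorKwargs(ProcessingKwargs, total=False):
    # override the images entry with the extended TypedDict above
    images_kwargs: MyImagesKwargs
    # _defaults holds the lowest-priority values consumed by _merge_kwargs
    _defaults = {
        "text_kwargs": {"padding": False},
        "images_kwargs": {"crop_to_patches": False},
    }
```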
@@ -280,7 +312,7 @@ class AudioKwargs(TypedDict, total=False): """ sampling_rate: Optional[int] - raw_speech: Optional[Union["np.ndarray", List[float], List["np.ndarray"], List[List[float]]]] + raw_speech: Optional[Union["np.ndarray", list[float], list["np.ndarray"], list[list[float]]]] padding: Optional[Union[bool, str, PaddingStrategy]] max_length: Optional[int] truncation: Optional[bool] @@ -351,13 +383,13 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): """ Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor. - tools (`List[Dict]`, *optional*): + tools (`list[dict]`, *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, giving the name, description and argument types for the tool. See our [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) for more information. - documents (`List[Dict[str, str]]`, *optional*): + documents (`list[dict[str, str]]`, *optional*): A list of dicts representing documents that will be accessible to the model if it is performing RAG (retrieval-augmented generation). If the template does not support RAG, this argument will have no effect. We recommend that each document should be a dict containing "title" and "text" keys. Please @@ -379,30 +411,22 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): This functionality is only available for chat templates that support it via the `{% generation %}` keyword. """ - tools: Optional[List[Dict]] = None - documents: Optional[List[Dict[str, str]]] = None + tools: Optional[list[dict]] = None + documents: Optional[list[dict[str, str]]] = None add_generation_prompt: Optional[bool] = False continue_final_message: Optional[bool] = False return_assistant_tokens_mask: Optional[bool] = False - -class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False): +class ChatTemplateLoadKwargs(TypedDict, total=False): """ - Keyword arguments for processor chat templates. + Keyword arguments used to load multimodal data in processor chat templates. - tokenize (`bool`, *optional*, defaults to `False`): - Whether to tokenize the output or not. - return_dict (`bool`, defaults to `False`): - Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. num_frames (`int`, *optional*): Number of frames to sample uniformly. If not passed, the whole video is loaded. video_load_backend (`str`, *optional*, defaults to `"pyav"`): The backend to use when loading the video which will be used only when there are videos in the conversation. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend that supports all types of sources to load from. - video_fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. sample_indices_fn (`Callable`, *optional*): A callable function that will return indices at which the video should be sampled. If the video has to be loaded using by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. 
@@ -415,18 +439,61 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs): return np.linspace(start_idx, end_idx, num_frames, dtype=int) """ - tokenize: Optional[bool] = False - return_dict: Optional[bool] = False - num_frames: Optional[int] = None video_load_backend: Optional[str] = "pyav" - video_fps: Optional[int] = None - sample_indices_fn: Optional[Callable] = None + sampling_rate: Optional[int] = 16_000 + load_audio_from_video: Optional[bool] = False +class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False): + """ + Keyword arguments for processor's `apply_chat_template`. + + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + """ + + tokenize: Optional[bool] = False + return_dict: Optional[bool] = False class AllKwargsForChatTemplate( TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs ): - ... + processor_kwargs: ProcessingKwargs = { + **ProcessingKwargs.__annotations__, + } + mm_load_kwargs: ChatTemplateLoadKwargs = { + **TextKwargs.__annotations__, + } + template_kwargs: ProcessorChatTemplateKwargs = { + **ProcessorChatTemplateKwargs.__annotations__, + } + + +@dataclass +class MultiModalData: + """ + Dataclass that holds extra useful data for processing + multimodal data. Processors currently cannot return keys, + unless it is used in model's forward. Thus we have helper + methods that calculate and return useful data from processing + input multimodals (images/videos). + Note that this dataclass is aimed to be used only in vLLM + and we might change its API in the future. 
+ """ + + num_image_tokens: list[int] = None + num_video_tokens: list[int] = None + num_audio_tokens: list[int] = None + num_image_patches: list[int] = None + + def __contains__(self, key): + return hasattr(self, key) and getattr(self, key) is not None + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise AttributeError(f"{self.__class__.__name__} has no attribute {key}") class ProcessorMixin(PushToHubMixin): @@ -435,20 +502,27 @@ class ProcessorMixin(PushToHubMixin): """ attributes = ["feature_extractor", "tokenizer"] - optional_attributes = ["chat_template"] - optional_call_args: List[str] = [] + optional_attributes = ["chat_template", "audio_tokenizer"] + optional_call_args: list[str] = [] # Names need to be attr_class for attr in attributes feature_extractor_class = None tokenizer_class = None _auto_class = None - valid_kwargs: List[str] = [] + valid_kwargs: list[str] = [] # args have to match the attributes class attribute def __init__(self, *args, **kwargs): # First, extract optional attributes from kwargs if present # Optional attributes can never be positional arguments for optional_attribute in self.optional_attributes: - setattr(self, optional_attribute, kwargs.pop(optional_attribute, None)) + optional_attribute_value = kwargs.pop(optional_attribute, None) + setattr(self, optional_attribute, optional_attribute_value) + + # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights + if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None: + proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value) + + # Sanitize args and kwargs for key in kwargs: if key not in self.attributes: @@ -467,40 +541,36 @@ def __init__(self, *args, **kwargs): # Check each arg is of the proper class (this will also catch a user initializing in the wrong order) for attribute_name, arg in kwargs.items(): - class_name = getattr(self, f"{attribute_name}_class") - # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class. - class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name) - if isinstance(class_name, tuple): - if "ImageProcess" in class_name[0]: - sub_path = os.path.abspath(os.path.dirname(__file__)) - sub_path = str(Path(sub_path).parent) - sys.path.insert(0, sub_path) - module_name = importlib.import_module("mindone.transformers") - proper_class = tuple(getattr(module_name, n) for n in class_name if n is not None) - else: - proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None) - elif "ImageProcess" in class_name: - sub_path = os.path.abspath(os.path.dirname(__file__)) - sub_path = str(Path(sub_path).parent) - sys.path.insert(0, sub_path) - module_name = importlib.import_module("mindone.transformers") - proper_class = getattr(module_name, class_name) - else: - proper_class = getattr(transformers_module, class_name) + self.check_argument_for_proper_class(attribute_name, arg) + setattr(self, attribute_name, arg) - if not isinstance(arg, proper_class): - raise TypeError( - f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected." - ) + def check_argument_for_proper_class(self, argument_name, argument): + """ + Checks the passed argument's class against the expected transformers class. In case of an unexpected + mismatch between expected and actual class, an error is raise. Otherwise, the proper retrieved class + is returned. 
+ """ + class_name = getattr(self, f"{argument_name}_class") + # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class. + class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name) + if isinstance(class_name, tuple): + proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None) + else: + proper_class = self.get_possibly_dynamic_module(class_name) - setattr(self, attribute_name, arg) + if not isinstance(argument, proper_class): + raise TypeError( + f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected." + ) + + return proper_class - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this processor instance. + `dict[str, Any]`: dictionary of all the attributes that make up this processor instance. """ output = copy.deepcopy(self.__dict__) @@ -521,10 +591,14 @@ def to_dict(self) -> Dict[str, Any]: del output["tokenizer"] if "image_processor" in output: del output["image_processor"] + if "video_processor" in output: + del output["video_processor"] if "feature_extractor" in output: del output["feature_extractor"] if "chat_template" in output: del output["chat_template"] + if "audio_tokenizer" in output: + del output["audio_tokenizer"] # Some attributes have different names but containing objects that are not simple strings output = { @@ -583,7 +657,7 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ use_auth_token = kwargs.pop("use_auth_token", None) @@ -614,13 +688,19 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): configs.append(self) custom_object_save(self, save_directory, config=configs) + save_jinja_files = kwargs.get("save_jinja_files", True) + for attribute_name in self.attributes: attribute = getattr(self, attribute_name) # Include the processor class in the attribute config so this processor can then be reloaded with the # `AutoProcessor` API. if hasattr(attribute, "_set_processor_class"): attribute._set_processor_class(self.__class__.__name__) - attribute.save_pretrained(save_directory) + if attribute_name == "tokenizer": + # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts + attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files) + else: + attribute.save_pretrained(save_directory) if self._auto_class is not None: # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up. 
@@ -632,18 +712,66 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): # If we save using the predefined names, we can load using `from_pretrained` # plus we save chat_template in its own file output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) - output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME) + output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE) + output_chat_template_file_legacy = os.path.join( + save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE + ) # Legacy filename + chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR) + output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME) processor_dict = self.to_dict() # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` # to avoid serializing chat template in json config file. So let's get it from `self` directly if self.chat_template is not None: - chat_template_json_string = ( - json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" - ) - with open(output_chat_template_file, "w", encoding="utf-8") as writer: - writer.write(chat_template_json_string) - logger.info(f"chat template saved in {output_chat_template_file}") + save_jinja_files = kwargs.get("save_jinja_files", True) + is_single_template = isinstance(self.chat_template, str) + if save_jinja_files and is_single_template: + # New format for single templates is to save them as chat_template.jinja + with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {output_chat_template_file_jinja}") + elif save_jinja_files and not is_single_template: + # New format for multiple templates is to save the default as chat_template.jinja + # and the other templates in the chat_templates/ directory + for template_name, template in self.chat_template.items(): + if template_name == "default": + with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f: + f.write(self.chat_template["default"]) + logger.info(f"chat template saved in {output_chat_template_file_jinja}") + else: + os.makedirs(chat_template_dir, exist_ok=True) + template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") + with open(template_filepath, "w", encoding="utf-8") as f: + f.write(template) + logger.info(f"chat template saved in {template_filepath}") + elif is_single_template: + # Legacy format for single templates: Put them in chat_template.json + chat_template_json_string = ( + json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" + ) + with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer: + writer.write(chat_template_json_string) + logger.info(f"chat template saved in {output_chat_template_file_legacy}") + elif self.chat_template is not None: + # At this point we have multiple templates in the legacy format, which is not supported + # chat template dicts are saved to chat_template.json as lists of dicts with fixed key names. + raise ValueError( + "Multiple chat templates are not supported in the legacy format. Please save them as " + "separate files using the `save_jinja_files` argument." 
+ ) + + if self.audio_tokenizer is not None: + audio_tokenizer_class = self.audio_tokenizer.__class__.__name__ + audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path + + audio_tokenizer_dict = { + "audio_tokenizer_class": audio_tokenizer_class, + "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path, + } + audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n" + + with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer: + writer.write(audio_tokenizer_json) # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and # `auto_map` is not specified. @@ -667,7 +795,7 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): @classmethod def get_processor_dict( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ) -> tuple[dict[str, Any], dict[str, Any]]: """ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a processor of type [`~processing_utils.ProcessingMixin`] using `from_args_and_dict`. @@ -680,8 +808,11 @@ def get_processor_dict( specify the folder name here. Returns: - `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object. + `tuple[dict, dict]`: The dictionary(ies) that will be used to instantiate the processor object. """ + # holding a copy for optionally loading the audio tokenizer (if available) + audio_tokenizer_kwargs = copy.deepcopy(kwargs) + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", None) @@ -706,21 +837,43 @@ def get_processor_dict( is_local = os.path.isdir(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME) - chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json") + additional_chat_template_files = {} + resolved_additional_chat_template_files = {} if os.path.isfile(pretrained_model_name_or_path): resolved_processor_file = pretrained_model_name_or_path - # cant't load chat-template when given a file as pretrained_model_name_or_path + # can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path resolved_chat_template_file = None + resolved_raw_chat_template_file = None + resolved_audio_tokenizer_file = None is_local = True elif is_remote_url(pretrained_model_name_or_path): processor_file = pretrained_model_name_or_path resolved_processor_file = download_url(pretrained_model_name_or_path) - # can't load chat-template when given a file url as pretrained_model_name_or_path + # can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path resolved_chat_template_file = None + resolved_raw_chat_template_file = None + resolved_audio_tokenizer_file = None else: + if is_local: + template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) + if template_dir.is_dir(): + for template_file in template_dir.glob("*.jinja"): + template_name = template_file.stem + additional_chat_template_files[template_name] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}" + else: + try: + for template in list_repo_templates( + pretrained_model_name_or_path, + local_files_only=local_files_only, + revision=revision, + cache_dir=cache_dir, + ): + additional_chat_template_files[template] = 
f"{CHAT_TEMPLATE_DIR}/{template}.jinja" + except EntryNotFoundError: + pass # No template dir means no template files processor_file = PROCESSOR_NAME - chat_template_file = CHAT_TEMPLATE_NAME + try: # Load from local folder or from cache or download from model Hub and cache resolved_processor_file = cached_file( @@ -738,12 +891,11 @@ def get_processor_dict( _raise_exceptions_for_missing_entries=False, ) - # Load chat template from a separate json if exists - # because making it part of processor-config break BC. - # Processors in older version do not accept any kwargs + # chat_template.json is a legacy file used by the processor class + # a raw chat_template.jinja is preferred in future resolved_chat_template_file = cached_file( pretrained_model_name_or_path, - chat_template_file, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -755,13 +907,61 @@ def get_processor_dict( subfolder=subfolder, _raise_exceptions_for_missing_entries=False, ) - except EnvironmentError: + + resolved_raw_chat_template_file = cached_file( + pretrained_model_name_or_path, + CHAT_TEMPLATE_FILE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + + resolved_additional_chat_template_files = { + template_name: cached_file( + pretrained_model_name_or_path, + template_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + for template_name, template_file in additional_chat_template_files.items() + } + + resolved_audio_tokenizer_file = cached_file( + pretrained_model_name_or_path, + AUDIO_TOKENIZER_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + except OSError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # the original exception. raise except Exception: # For any other exception, we throw a generic error. - raise EnvironmentError( + raise OSError( f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. 
Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" @@ -769,30 +969,69 @@ def get_processor_dict( ) # Add chat template as kwarg before returning because most models don't have processor config - chat_template = None if resolved_chat_template_file is not None: - with open(resolved_chat_template_file, "r", encoding="utf-8") as reader: - text = reader.read() - chat_template = json.loads(text)["chat_template"] - kwargs["chat_template"] = chat_template + # This is the legacy path + with open(resolved_chat_template_file, encoding="utf-8") as reader: + chat_template_json = json.loads(reader.read()) + chat_templates = {"default": chat_template_json["chat_template"]} + if resolved_additional_chat_template_files: + raise ValueError( + "Cannot load chat template due to conflicting files - this checkpoint combines " + "a legacy chat_template.json file with separate template files, which is not " + "supported. To resolve this error, replace the legacy chat_template.json file " + "with a modern chat_template.jinja file." + ) + else: + chat_templates = { + template_name: open(template_file, "r", encoding="utf-8").read() + for template_name, template_file in resolved_additional_chat_template_files.items() + } + if resolved_raw_chat_template_file is not None: + with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader: + chat_templates["default"] = reader.read() + if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1: + chat_templates = chat_templates["default"] # Flatten when we just have a single template/file + + if chat_templates: + kwargs["chat_template"] = chat_templates + + # Same as chat template, adding as kwarg after loading the model + audio_tokenizer = None + if resolved_audio_tokenizer_file is not None: + with open(resolved_audio_tokenizer_file, "r", encoding="utf-8") as reader: + # The json contains the references we need to init the correct model + audio_tokenizer_references = json.load(reader) + audio_tokenizer_class = cls.get_possibly_dynamic_module( + audio_tokenizer_references["audio_tokenizer_class"] + ) + audio_tokenizer_path = audio_tokenizer_references["audio_tokenizer_name_or_path"] + + audio_tokenizer = audio_tokenizer_class.from_pretrained(audio_tokenizer_path, **audio_tokenizer_kwargs) + + if audio_tokenizer is not None: + kwargs["audio_tokenizer"] = audio_tokenizer # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict. # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception) # However, for models added in the future, we won't get the expected error if this file is missing. 
if resolved_processor_file is None: - return {}, kwargs + # In any case we need to pass `chat_template` if it is available + processor_dict = {} + if "chat_template" in kwargs: + processor_dict["chat_template"] = kwargs.pop("chat_template") + if "audio_tokenizer" in kwargs: + processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer") + return processor_dict, kwargs try: # Load processor dict - with open(resolved_processor_file, "r", encoding="utf-8") as reader: + with open(resolved_processor_file, encoding="utf-8") as reader: text = reader.read() processor_dict = json.loads(text) except json.JSONDecodeError: - raise EnvironmentError( - f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file." - ) + raise OSError(f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file.") if is_local: logger.info(f"loading configuration file {resolved_processor_file}") @@ -801,33 +1040,28 @@ def get_processor_dict( if "chat_template" in processor_dict and processor_dict["chat_template"] is not None: logger.warning_once( - "Chat templates should be in a 'chat_template.json' file but found key='chat_template' " + "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' " "in the processor's config. Make sure to move your template to its own file." ) - if not is_local: - if "auto_map" in processor_dict: - processor_dict["auto_map"] = add_model_info_to_auto_map( - processor_dict["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in processor_dict: - processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( - processor_dict["custom_pipelines"], pretrained_model_name_or_path - ) + if "chat_template" in kwargs: + processor_dict["chat_template"] = kwargs.pop("chat_template") + if "audio_tokenizer" in kwargs: + processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer") return processor_dict, kwargs @classmethod - def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): + def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs): """ Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters. Args: - processor_dict (`Dict[str, Any]`): + processor_dict (`dict[str, Any]`): Dictionary that will be used to instantiate the processor object. Such a dictionary can be retrieved from a pretrained checkpoint by leveraging the [`~processing_utils.ProcessingMixin.to_dict`] method. - kwargs (`Dict[str, Any]`): + kwargs (`dict[str, Any]`): Additional parameters from which to initialize the processor object. 
Returns: @@ -836,7 +1070,6 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): """ processor_dict = processor_dict.copy() return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - chat_template = kwargs.pop("chat_template", None) # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs # If we don't pop, some specific kwargs will raise a warning @@ -846,29 +1079,40 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): if "auto_map" in processor_dict: del processor_dict["auto_map"] - unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs) - processor = cls(*args, **processor_dict) - if chat_template is not None: - setattr(processor, "chat_template", chat_template) + # override processor_dict with given kwargs + processor_dict.update(kwargs) + + # check if there is an overlap between args and processor_dict + accepted_args_and_kwargs = cls.__init__.__code__.co_varnames[: cls.__init__.__code__.co_argcount][1:] + + # validate both processor_dict and given kwargs + unused_kwargs, valid_kwargs = cls.validate_init_kwargs( + processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs + ) + + # update args that are already in processor_dict to avoid duplicate arguments + args_to_update = { + i: valid_kwargs.pop(arg) + for i, arg in enumerate(accepted_args_and_kwargs) + if (arg in valid_kwargs and i < len(args)) + } + args = [arg if i not in args_to_update else args_to_update[i] for i, arg in enumerate(args)] - # Update processor with kwargs if needed - for key in set(kwargs.keys()): - if hasattr(processor, key): - setattr(processor, key, kwargs.pop(key)) + # instantiate processor with used (and valid) kwargs only + processor = cls(*args, **valid_kwargs) - kwargs.update(unused_kwargs) logger.info(f"Processor {processor}") if return_unused_kwargs: - return processor, kwargs + return processor, unused_kwargs else: return processor def _merge_kwargs( self, ModelProcessorKwargs: ProcessingKwargs, - tokenizer_init_kwargs: Optional[Dict] = None, + tokenizer_init_kwargs: Optional[dict] = None, **kwargs, - ) -> Dict[str, Dict]: + ) -> dict[str, dict]: """ Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance. The order of operations is as follows: @@ -885,7 +1129,7 @@ def _merge_kwargs( ```python tokenizer = tokenizer_class(..., {"padding": "max_length"}) image_processor = image_processor_class(...) - processor(tokenizer, image_processor) # will pass max_length unless overriden by kwargs at call + processor(tokenizer, image_processor) # will pass max_length unless overridden by kwargs at call ``` 4) defaults kwargs specified at processor level have lowest priority. 
```python @@ -925,15 +1169,16 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg "common_kwargs": {}, } + possible_modality_keywords = {"text", "audio", "videos", "images"} used_keys = set() # get defaults from set model processor kwargs if they exist - for modality in default_kwargs: + for modality in default_kwargs: # noqa: PLC0206 default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy() # update defaults with arguments from tokenizer init for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): # init with tokenizer init kwargs if necessary - if modality_key in tokenizer_init_kwargs: + if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs: value = ( getattr(self.tokenizer, modality_key) if hasattr(self.tokenizer, modality_key) @@ -946,7 +1191,7 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg # update modality kwargs with passed kwargs non_modality_kwargs = set(kwargs) - set(output_kwargs) - for modality in output_kwargs: + for modality, output_kwarg in output_kwargs.items(): for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): # check if we received a structured kwarg dict or not to handle it correctly if modality in kwargs: @@ -963,8 +1208,8 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg kwarg_value = kwargs.get(modality_key, "__empty__") else: kwarg_value = "__empty__" - if kwarg_value != "__empty__": - output_kwargs[modality][modality_key] = kwarg_value + if not isinstance(kwarg_value, str) or kwarg_value != "__empty__": + output_kwarg[modality_key] = kwarg_value used_keys.add(modality_key) # Determine if kwargs is a flat dictionary or contains nested dictionaries @@ -978,18 +1223,23 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg used_keys.add(subkey) else: # kwargs is a flat dictionary - for key in kwargs: + for key, kwarg in kwargs.items(): if key not in used_keys: - output_kwargs["common_kwargs"][key] = kwargs[key] + if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys(): + output_kwargs["common_kwargs"][key] = kwarg + elif key not in possible_modality_keywords: + logger.warning_once( + f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored." + ) # all modality-specific kwargs are updated with common kwargs - for modality in output_kwargs: - output_kwargs[modality].update(output_kwargs["common_kwargs"]) + for kwarg in output_kwargs.values(): + kwarg.update(output_kwargs["common_kwargs"]) return output_kwargs @classmethod def from_pretrained( - cls, + cls: type[SpecificProcessorType], pretrained_model_name_or_path: Union[str, os.PathLike], cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, @@ -997,7 +1247,7 @@ def from_pretrained( token: Optional[Union[str, bool]] = None, revision: str = "main", **kwargs, - ): + ) -> SpecificProcessorType: r""" Instantiate a processor associated with a pretrained model. @@ -1057,11 +1307,6 @@ def register_for_auto_class(cls, auto_class="AutoProcessor"): Register this class with a given auto class. This should only be used for custom feature extractors as the ones in the library are already mapped with `AutoProcessor`. - - - This API is experimental and may have some slight breaking changes in the next releases. 
- - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`): @@ -1114,8 +1359,10 @@ def get_possibly_dynamic_module(module_name): return getattr(transformers_module, module_name) lookup_locations = [ transformers_module.IMAGE_PROCESSOR_MAPPING, + transformers_module.VIDEO_PROCESSOR_MAPPING, transformers_module.TOKENIZER_MAPPING, transformers_module.FEATURE_EXTRACTOR_MAPPING, + transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING, ] for lookup_location in lookup_locations: for custom_class in lookup_location._extra_content.values(): @@ -1139,112 +1386,21 @@ def model_input_names(self): @staticmethod def validate_init_kwargs(processor_config, valid_kwargs): - kwargs_from_config = processor_config.keys() - unused_kwargs = {} - unused_keys = set(kwargs_from_config) - set(valid_kwargs) - if unused_keys: - unused_key_str = ", ".join(unused_keys) - logger.warning( - f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. " - ) - unused_kwargs = {k: processor_config[k] for k in unused_keys} - return unused_kwargs - - def prepare_and_validate_optional_call_args(self, *args): - """ - Matches optional positional arguments to their corresponding names in `optional_call_args` - in the processor class in the order they are passed to the processor call. - - Note that this should only be used in the `__call__` method of the processors with special - arguments. Special arguments are arguments that aren't `text`, `images`, `audio`, nor `videos` - but also aren't passed to the tokenizer, image processor, etc. Examples of such processors are: - - `CLIPSegProcessor` - - `LayoutLMv2Processor` - - `OwlViTProcessor` - - Also note that passing by position to the processor call is now deprecated and will be disallowed - in future versions. We only have this for backward compatibility. - - Example: - Suppose that the processor class has `optional_call_args = ["arg_name_1", "arg_name_2"]`. - And we define the call method as: - ```python - def __call__( - self, - text: str, - images: Optional[ImageInput] = None, - *arg, - audio=None, - videos=None, - ) - ``` - - Then, if we call the processor as: - ```python - images = [...] - processor("What is common in these images?", images, arg_value_1, arg_value_2) - ``` - - Then, this method will return: - ```python - { - "arg_name_1": arg_value_1, - "arg_name_2": arg_value_2, - } - ``` - which we could then pass as kwargs to `self._merge_kwargs` - """ - if len(args): - warnings.warn( - "Passing positional arguments to the processor call is now deprecated and will be disallowed in v4.47. " - "Please pass all arguments as keyword arguments." - ) - if len(args) > len(self.optional_call_args): - raise ValueError( - f"Expected *at most* {len(self.optional_call_args)} optional positional arguments in processor call" - f"which will be matched with {' '.join(self.optional_call_args)} in the order they are passed." - f"However, got {len(args)} positional arguments instead." - "Please pass all arguments as keyword arguments instead (e.g. `processor(arg_name_1=..., arg_name_2=...))`." 
- ) - return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)} + kwargs_from_config = set(processor_config.keys()) + valid_kwargs_set = set(valid_kwargs) - def _process_messages_for_chat_template( - self, - conversation: List[List[Dict[str, str]]], - batch_images: List[ImageInput], - batch_videos: List[VideoInput], - batch_video_metadata: List[List[Dict[str, any]]], - **chat_template_kwargs: Unpack[AllKwargsForChatTemplate], - ): - """ - Used within `apply_chat_template` when a model has a special way to process conversation history. For example, - video models might want to specify in the prompt the duration of video or which frame indices at which timestamps - were sampled. This information cannot be accessed before the video is loaded. + unused_keys = kwargs_from_config - valid_kwargs_set + valid_keys = kwargs_from_config & valid_kwargs_set - For most models it is a no-op, and must be overridden by model processors which require special processing. - - Args: - conversation (`List[Dict, str, str]`): - The conversation to process. Always comes in batched format. - batch_images (`List[List[ImageInput]]`): - Batch of images that were loaded from url/path defined in the conversation. The images - are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images - per batch. - batch_videos (`List[List[ImageInput]]`): - Batch of videos that were loaded from url/path defined in the conversation. The videos - are ordered in the samm way as in the conversation. Comes in nested list format, one list of 4D video arrays - per batch. - batch_video_metadata (`List[List[Dict[[str, any]]]]`): - Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video. - Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays - per batch. + unused_kwargs = {k: processor_config[k] for k in unused_keys} if unused_keys else {} + valid_kwargs = {k: processor_config[k] for k in valid_keys} if valid_keys else {} - """ - return conversation + return unused_kwargs, valid_kwargs + @deprecate_kwarg("video_fps", version="4.58", new_name="fps") def apply_chat_template( self, - conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], chat_template: Optional[str] = None, **kwargs: Unpack[AllKwargsForChatTemplate], ) -> str: @@ -1260,43 +1416,77 @@ def apply_chat_template( { "role": "user", "content": [ - {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, {"type": "text", "text": "Please describe this image in detail."}, ], }, ] Args: - conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`): + conversation (`Union[list[Dict, [str, str]], list[list[dict[str, str]]]]`): The conversation to format. chat_template (`Optional[str]`, *optional*): The Jinja template to use for formatting the conversation. If not provided, the tokenizer's chat template is used. """ - if chat_template is None: - if self.chat_template is not None: + if isinstance(self.chat_template, dict) and "default" in self.chat_template: + chat_template = self.chat_template["default"] + elif isinstance(self.chat_template, dict): + raise ValueError( + 'The processor has multiple chat templates but none of them are named "default". 
You need to specify' + " which one to use by passing the `chat_template` argument. Available templates are: " + f"{', '.join(self.chat_template.keys())}" + ) + elif self.chat_template is not None: chat_template = self.chat_template else: raise ValueError( - "No chat template is set for this processor. Please either set the `chat_template` attribute, " - "or provide a chat template as an argument. See " - "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." + "Cannot use apply_chat_template because this processor does not have a chat template." + ) + else: + if isinstance(self.chat_template, dict) and chat_template in self.chat_template: + # It's the name of a template, not a full template string + chat_template = self.chat_template[chat_template] + else: + # It's a template string, render it directly + chat_template = chat_template + + is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast") + + if kwargs.get("continue_final_message", False): + if kwargs.get("add_generation_prompt", False): + raise ValueError( + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + ) + if kwargs.get("return_assistant_tokens_mask", False): + raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") + + if kwargs.get("return_assistant_tokens_mask", False): + if not is_tokenizers_fast: + raise ValueError( + "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. " + "If the error persists, open an issue to support a Fast tokenizer for your model." 
) + else: + kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries - # Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template` - # and for multimodal chat template - tokenizer_template_kwargs = {} - for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys(): - tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None) - value = kwargs.pop(tokenizer_key, tokenizer_value) - tokenizer_template_kwargs[tokenizer_key] = value + # Fill sets of kwargs that should be used by different parts of template + processed_kwargs = { + "mm_load_kwargs": {}, + "template_kwargs": {}, + } + + for kwarg_type in processed_kwargs: + for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys(): + kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type] + default_value = getattr(kwarg_type_defaults, key, None) + value = kwargs.pop(key, default_value) + if value is not None and not isinstance(value, dict): + processed_kwargs[kwarg_type][key] = value - chat_template_kwargs = {} - for key in ProcessorChatTemplateKwargs.__annotations__.keys(): - processor_value = getattr(ProcessorChatTemplateKwargs, key, None) - value = kwargs.pop(key, processor_value) - chat_template_kwargs[key] = value + # Pass unprocessed custom kwargs + processed_kwargs["template_kwargs"].update(kwargs) if isinstance(conversation, (list, tuple)) and ( isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") @@ -1307,21 +1497,25 @@ def apply_chat_template( is_batched = False conversations = [conversation] - num_frames = chat_template_kwargs.get("num_frames") - video_fps = chat_template_kwargs.get("video_fps") - video_load_backend = chat_template_kwargs.get("video_load_backend") - tokenize = chat_template_kwargs.get("tokenize") - return_dict = chat_template_kwargs.get("return_dict") - sample_indices_fn = chat_template_kwargs.get("sample_indices_fn") + tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False) + return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False) + mm_load_kwargs = processed_kwargs["mm_load_kwargs"] if tokenize: batch_images, batch_videos = [], [] + batch_audios = [] batch_video_metadata = [] for conversation in conversations: images, videos = [], [] video_metadata = [] for message in conversation: visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] + audio_fnames = [ + content[key] + for content in message["content"] + for key in ["audio", "url", "path"] + if key in content and content["type"] == "audio" + ] image_fnames = [ vision_info[key] for vision_info in visuals @@ -1334,25 +1528,32 @@ def apply_chat_template( for key in ["video", "url", "path"] if key in vision_info and vision_info["type"] == "video" ] + for fname in image_fnames: images.append(load_image(fname)) + + # Audio models do not accept nested list of audios (yet!) 
so we construct a flat input audio list + if not mm_load_kwargs["load_audio_from_video"]: + for fname in audio_fnames: + batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) + else: + for fname in video_fnames: + batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) + for fname in video_fnames: if isinstance(fname, (list, tuple)) and isinstance(fname[0], str): - video = [np.array(load_image(image_fname)).T for image_fname in fname] + video = [np.array(load_image(image_fname)) for image_fname in fname] # create a 4D video because `load_video` always returns a 4D array video = np.stack(video) metadata = None logger.warning( "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " - "If you model applies special processing based on metadata, please load the whole video and let the model sample frames." + "If your model requires metadata during processing, please load the whole video and let the processor sample frames instead." ) else: video, metadata = load_video( fname, - num_frames=num_frames, - fps=video_fps, - backend=video_load_backend, - sample_indices_fn=sample_indices_fn, + backend=mm_load_kwargs["video_load_backend"], ) videos.append(video) video_metadata.append(metadata) @@ -1365,21 +1566,11 @@ def apply_chat_template( batch_videos.append(videos) batch_video_metadata.append(video_metadata) - # Process conversation with video/image information if needed. Then convert into a prompt using Jinja template - conversations = self._process_messages_for_chat_template( - conversations, - batch_images=batch_images, - batch_videos=batch_videos, - batch_video_metadata=batch_video_metadata, - **chat_template_kwargs, - ) - - prompt = self.tokenizer.apply_chat_template( - conversations, + prompt, generation_indices = render_jinja_template( + conversations=conversations, chat_template=chat_template, - tokenize=False, - return_dict=False, - **tokenizer_template_kwargs, + **processed_kwargs["template_kwargs"], # different flags such as `return_assistant_mask` + **self.tokenizer.special_tokens_map, # tokenizer special tokens are used by some templates ) if not is_batched: @@ -1396,13 +1587,44 @@ def apply_chat_template( if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token): kwargs["add_special_tokens"] = False + # Always sample frames by default unless explicitly set to `False` by users. If users do not pass `num_frames`/`video_fps` + # sampling should not done for BC. 
+ if "do_sample_frames" not in kwargs and ("fps" in kwargs or "num_frames" in kwargs): + kwargs["do_sample_frames"] = True + out = self( text=prompt, images=batch_images if batch_images else None, videos=batch_videos if batch_videos else None, + audio=batch_audios if batch_audios else None, + video_metadata=batch_video_metadata, **kwargs, ) + if return_dict: + if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False): + assistant_masks = [] + offset_mapping = out.pop("offset_mapping") + input_ids = out["input_ids"] + for i in range(len(input_ids)): + current_mask = [0] * len(input_ids[i]) + offsets = offset_mapping[i] + offset_starts = [start for start, end in offsets] + for assistant_start_char, assistant_end_char in generation_indices[i]: + start_pos = bisect.bisect_left(offset_starts, assistant_start_char) + end_pos = bisect.bisect_left(offset_starts, assistant_end_char) + + if not ( + start_pos >= 0 + and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1] + ): + # start_token is out of bounds maybe due to truncation. + continue + for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])): + current_mask[token_id] = 1 + assistant_masks.append(current_mask) + out["assistant_masks"] = assistant_masks + out.convert_to_tensors(tensor_type=kwargs.get("return_tensors", None)) return out else: return out["input_ids"] @@ -1422,68 +1644,27 @@ def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens Additional arguments to be passed to the tokenizer's `batch_decode method`. Returns: - `List[str]`: The decoded text. + `list[str]`: The decoded text. """ return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) -def _validate_images_text_input_order(images, text): - """ - For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped. - This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes. - Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled - in the processor's `__call__` method before calling this method. - """ - - def is_url(val) -> bool: - return isinstance(val, str) and val.startswith("http") - - def _is_valid_images_input_for_processor(imgs): - # If we have an list of images, make sure every image is valid - if isinstance(imgs, (list, tuple)): - for img in imgs: - if not _is_valid_images_input_for_processor(img): - return False - # If not a list or tuple, we have been given a single image or batched tensor of images - elif not (is_valid_image(imgs) or is_url(imgs)): - return False - return True - - def _is_valid_text_input_for_processor(t): - if isinstance(t, str): - # Strings are fine - return True - elif isinstance(t, (list, tuple)): - # List are fine as long as they are... - if len(t) == 0: - # ... 
not empty - return False - for t_s in t: - return _is_valid_text_input_for_processor(t_s) - return False - - def _is_valid(input, validator): - return validator(input) or input is None - - images_is_valid = _is_valid(images, _is_valid_images_input_for_processor) - images_is_text = _is_valid_text_input_for_processor(images) - - text_is_valid = _is_valid(text, _is_valid_text_input_for_processor) - text_is_images = _is_valid_images_input_for_processor(text) - # Handle cases where both inputs are valid - if images_is_valid and text_is_valid: - return images, text - - # Handle cases where inputs need to and can be swapped - if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images): - logger.warning_once( - "You may have used the wrong order for inputs. `images` should be passed before `text`. " - "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47." - ) - return text, images - - raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.") + def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]): + """ + Checks that number of special tokens in text and processed text is same. The count can be different + if tokenized text was truncated, leading to issues in model code. + """ + for modality in modalities: + token_str = getattr(self, f"{modality}_token") + token_id = getattr(self, f"{modality}_token_id") + ids_count = [list(ids).count(token_id) for ids in text_inputs["input_ids"]] + text_count = [sample.count(token_str) for sample in text] + if ids_count != text_count: + raise ValueError( + f"Mismatch in `{modality}` token count between text and `input_ids`. Got ids={ids_count} and text={text_count}. " + "Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`." 
+ ) ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) if ProcessorMixin.push_to_hub.__doc__ is not None: From 8a90ca65202d973a0ccd056f9f1cf83b2156da65 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:21:26 +0800 Subject: [PATCH 27/94] remove add_model_info_to_auto_map & update video_utils.py --- mindone/transformers/utils/import_utils.py | 20 +- mindone/transformers/video_utils.py | 671 +++++++++++++++++++++ 2 files changed, 689 insertions(+), 2 deletions(-) create mode 100644 mindone/transformers/video_utils.py diff --git a/mindone/transformers/utils/import_utils.py b/mindone/transformers/utils/import_utils.py index 33ecca2238..83265524a0 100644 --- a/mindone/transformers/utils/import_utils.py +++ b/mindone/transformers/utils/import_utils.py @@ -64,9 +64,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ else: return package_exists - +_av_available = importlib.util.find_spec("av") is not None +_decord_available = importlib.util.find_spec("decord") is not None _scipy_available = _is_package_available("scipy") - +_cv2_available = importlib.util.find_spec("cv2") is not None +_yt_dlp_available = importlib.util.find_spec("yt_dlp") is not None def is_mindspore_available(): _mindspore_available, _mindspore_version = _is_package_available("mindspore", return_version=True) @@ -82,6 +84,20 @@ def is_scipy_available(): return _scipy_available +def is_av_available(): + return _av_available + + +def is_decord_available(): + return _decord_available + +def is_cv2_available(): + return _cv2_available + + +def is_yt_dlp_available(): + return _yt_dlp_available + @lru_cache def is_vision_available(): _pil_available = importlib.util.find_spec("PIL") is not None diff --git a/mindone/transformers/video_utils.py b/mindone/transformers/video_utils.py new file mode 100644 index 0000000000..a07e168c8f --- /dev/null +++ b/mindone/transformers/video_utils.py @@ -0,0 +1,671 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import warnings +from collections.abc import Iterable +from contextlib import redirect_stdout +from dataclasses import dataclass +from io import BytesIO +from typing import Callable, Optional, Union +from urllib.parse import urlparse + +import numpy as np +import requests + +from .image_transforms import PaddingMode, to_channel_dimension_format +from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image +from .utils import ( + is_av_available, + is_cv2_available, + is_decord_available, + is_numpy_array, + is_vision_available, + is_yt_dlp_available, + requires_backends, + is_mindspore_available, + is_mindspore_tensor, +) +from transformers.utils import logging + + +if is_vision_available(): + import PIL.Image + import PIL.ImageOps + +if is_mindspore_available(): + import mindspore + +logger = logging.get_logger(__name__) + + +VideoInput = Union[ + list["PIL.Image.Image"], + "np.ndarray", + "mindspore.Tensor", + list["np.ndarray"], + list["mindspore.Tensor"], + list[list["PIL.Image.Image"]], + list[list["np.ndarrray"]], + list[list["mindspore.Tensor"]], +] # noqa + + +@dataclass +class VideoMetadata: + total_num_frames: int + fps: float + duration: float + video_backend: str + + def __getitem__(self, item): + return getattr(self, item) + + +def is_valid_video_frame(frame): + return isinstance(frame, PIL.Image.Image) or ( + (is_numpy_array(frame) or is_mindspore_tensor(frame)) and frame.ndim == 3 + ) + + +def is_valid_video(video): + if not isinstance(video, (list, tuple)): + return (is_numpy_array(video) or is_mindspore_tensor(video)) and video.ndim == 4 + return all(is_valid_video_frame(frame) for frame in video) + + +def valid_videos(videos): + # If we have a list of videos, it could be either one video as list of frames or a batch + if isinstance(videos, (list, tuple)): + for video_or_frame in videos: + if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)): + return False + # If not a list, then we have a single 4D video or 5D batched tensor + elif not is_valid_video(videos) or videos.ndim == 5: + return False + return True + + +def is_batched_video(videos): + if isinstance(videos, (list, tuple)): + return is_valid_video(videos[0]) + elif (is_numpy_array(videos) or is_mindspore_tensor(videos)) and videos.ndim == 5: + return True + return False + + +def is_scaled_video(video: np.ndarray) -> bool: + """ + Checks to see whether the pixel values have already been rescaled to [0, 1]. + """ + # It's possible the video has pixel values in [0, 255] but is of floating type + return np.min(video) >= 0 and np.max(video) <= 1 + + +def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union["np.ndarray", "mindspore.Tensor"]]: + """ + Given a batch of videos, converts each video to a 4D array. If video is already in array type, + it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element. + + Args: + videos (`VideoInput`): + Video inputs to turn into a list of videos. + """ + + if not isinstance(videos[0], (list, tuple)): + return videos + + video_converted = [] + for video in videos: + video = [np.array(frame) for frame in video] + video = np.stack(video) + video_converted.append(video) + return video_converted + + +def make_batched_videos(videos) -> list[Union["np.ndarray", "mindspore.Tensor"]]: + """ + Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1. 
+ If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image` + frames are converted to 4D arrays. + + We assume that all inputs in the list are in the same format, based on the type of the first element. + + Args: + videos (`VideoInput`): + Video inputs to turn into a list of videos. + """ + if not valid_videos: + raise ValueError( + f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got" + f" type {type(videos)}." + ) + + if is_batched_video(videos): + pass + elif is_valid_video(videos): + videos = [videos] + # only one frame passed, thus we unsqueeze time dim + elif is_valid_image(videos): + videos = [np.array(videos)[None, ...]] + # nested batch so we need to unflatten + elif isinstance(videos[0], (list, tuple)) and is_valid_video(videos[0][0]): + videos = [video for sublist in videos for video in sublist] + return convert_pil_frames_to_video(videos) + + +def get_video_size(video: np.ndarray, channel_dim: ChannelDimension = None) -> tuple[int, int]: + """ + Returns the (height, width) dimensions of the video. + + Args: + video (`np.ndarray`): + The video to get the dimensions of. + channel_dim (`ChannelDimension`, *optional*): + Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video. + + Returns: + A tuple of the video's height and width. + """ + if channel_dim is None: + channel_dim = infer_channel_dimension_format(video) + + if channel_dim == ChannelDimension.FIRST: + return video.shape[-2], video.shape[-1] + elif channel_dim == ChannelDimension.LAST: + return video.shape[-3], video.shape[-2] + else: + raise ValueError(f"Unsupported data format: {channel_dim}") + + +def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None): + """ + Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames` + when loading a video. + + Args: + total_num_frames (`int`): + Total number of frames that a video has. + num_frames (`int`, *optional*): + Number of frames to sample uniformly. If not specified, all frames are sampled. + + Returns: + np.ndarray: np array of frame indices that will be sampled. + """ + if num_frames is not None: + indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int) + else: + indices = np.arange(0, total_num_frames).astype(int) + return indices + + +def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs): + """ + A default sampling function that replicates the logic used in get_uniform_frame_indices, + while optionally handling `fps` if `num_frames` is not provided. + + Args: + metadata (`VideoMetadata`): + `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps". + num_frames (`int`, *optional*): + Number of frames to sample uniformly. + fps (`int` or `float`, *optional*): + Desired frames per second. Takes priority over num_frames if both are provided. + + Returns: + `np.ndarray`: Array of frame indices to sample. + """ + total_num_frames = metadata.total_num_frames + video_fps = metadata.fps + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is None and fps is not None: + num_frames = int(total_num_frames / video_fps * fps) + if num_frames > total_num_frames: + raise ValueError( + f"When loading the video with fps={fps}, we computed num_frames={num_frames} " + f"which exceeds total_num_frames={total_num_frames}. 
Check fps or video metadata." + ) + + if num_frames is not None: + indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int) + else: + indices = np.arange(0, total_num_frames, dtype=int) + return indices + + +def read_video_opencv( + video_path: str, + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode a video using the OpenCV backend. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + # Lazy import cv2 + requires_backends(read_video_opencv, ["cv2"]) + import cv2 + + video = cv2.VideoCapture(video_path) + total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + video_fps = video.get(cv2.CAP_PROP_FPS) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv" + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) + + index = 0 + frames = [] + while video.isOpened(): + success, frame = video.read() + if not success: + break + if index in indices: + height, width, channel = frame.shape + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame[0:height, 0:width, 0:channel]) + if success: + index += 1 + if index >= total_num_frames: + break + + video.release() + metadata.frames_indices = indices + return np.stack(frames), metadata + + +def read_video_decord( + video_path: str, + sample_indices_fn: Optional[Callable] = None, + **kwargs, +): + """ + Decode a video using the Decord backend. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. 
+ """ + # Lazy import from decord + requires_backends(read_video_decord, ["decord"]) + from decord import VideoReader, cpu + + vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu + video_fps = vr.get_avg_fps() + total_num_frames = len(vr) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord" + ) + + indices = sample_indices_fn(metadata=metadata, **kwargs) + + frames = vr.get_batch(indices).asnumpy() + metadata.frames_indices = indices + return frames, metadata + + +def read_video_pyav( + video_path: str, + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode the video with PyAV decoder. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + # Lazy import av + requires_backends(read_video_pyav, ["av"]) + import av + + container = av.open(video_path) + total_num_frames = container.streams.video[0].frames + video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav" + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) + + frames = [] + container.seek(0) + end_index = indices[-1] + for i, frame in enumerate(container.decode(video=0)): + if i > end_index: + break + if i >= 0 and i in indices: + frames.append(frame) + + video = np.stack([x.to_ndarray(format="rgb24") for x in frames]) + metadata.frames_indices = indices + return video, metadata + + + + +VIDEO_DECODERS = { + "decord": read_video_decord, + "opencv": read_video_opencv, + "pyav": read_video_pyav, +} + + +def load_video( + video: Union[str, "VideoInput"], + num_frames: Optional[int] = None, + fps: Optional[Union[int, float]] = None, + backend: str = "pyav", + sample_indices_fn: Optional[Callable] = None, + **kwargs, +) -> np.array: + """ + Loads `video` to a numpy array. + + Args: + video (`str` or `VideoInput`): + The video to convert to the numpy array format. Can be a link to video or local path. + num_frames (`int`, *optional*): + Number of frames to sample uniformly. If not passed, the whole video is loaded. + fps (`int` or `float`, *optional*): + Number of frames to sample per second. Should be passed only when `num_frames=None`. + If not specified and `num_frames==None`, all frames are sampled. + backend (`str`, *optional*, defaults to `"pyav"`): + The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv",]. Defaults to "pyav". + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. 
If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args. + The function expects at input the all args along with all kwargs passed to `load_video` and should output valid + indices at which the video should be sampled. For example: + + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, Dict]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - Metadata dictionary. + """ + + # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn` + if fps is not None and num_frames is not None and sample_indices_fn is None: + raise ValueError( + "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" + ) + + # If user didn't pass a sampling function, create one on the fly with default logic + if sample_indices_fn is None: + + def sample_indices_fn_func(metadata, **fn_kwargs): + return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs) + + sample_indices_fn = sample_indices_fn_func + + if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]: + if not is_yt_dlp_available(): + raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.") + # Lazy import from yt_dlp + requires_backends(load_video, ["yt_dlp"]) + from yt_dlp import YoutubeDL + + buffer = BytesIO() + with redirect_stdout(buffer), YoutubeDL() as f: + f.download([video]) + bytes_obj = buffer.getvalue() + file_obj = BytesIO(bytes_obj) + elif video.startswith("http://") or video.startswith("https://"): + file_obj = BytesIO(requests.get(video).content) + elif os.path.isfile(video): + file_obj = video + elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])): + file_obj = None + else: + raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.") + + # can also load with decord, but not cv2 + # both will fail in case of url links + video_is_url = video.startswith("http://") or video.startswith("https://") + if video_is_url and backend in ["opencv"]: + raise ValueError( + "If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` as backend" + ) + + if file_obj is None: + return video + + if ( + (not is_decord_available() and backend == "decord") + or (not is_av_available() and backend == "pyav") + or (not is_cv2_available() and backend == "opencv") + ): + raise ImportError( + f"You chose backend={backend} for loading the video but the required library is not found in your environment " + f"Make sure to install {backend} before loading the video." + ) + + video_decoder = VIDEO_DECODERS[backend] + video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs) + return video, metadata + + +def convert_to_rgb( + video: np.array, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.array: + """ + Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it. + + Args: + video (`np.array`): + The video to convert. 
+ data_format (`ChannelDimension`, *optional*): + The channel dimension format of the output video. If unset, will use the inferred format from the input. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input video. If unset, will use the inferred format from the input. + """ + if not isinstance(video, np.ndarray): + raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}") + + # np.array usually comes with ChannelDimension.LAST so leet's convert it + if input_data_format is None: + input_data_format = infer_channel_dimension_format(video) + video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format) + + # 3 channels for RGB already + if video.shape[-3] == 3: + return video + + # Grayscale video so we repeat it 3 times for each channel + if video.shape[-3] == 1: + return video.repeat(3, -3) + + if not (video[..., 3, :, :] < 255).any(): + return video + + # There is a transparency layer, blend it with a white background. + # Calculate the alpha proportion for blending. + alpha = video[..., 3, :, :] / 255.0 + video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., 3, :, :] + return video + + +def pad( + video: np.ndarray, + padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]], + mode: PaddingMode = PaddingMode.CONSTANT, + constant_values: Union[float, Iterable[float]] = 0.0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Pads the `video` with the specified (height, width) `padding` and `mode`. + + Args: + video (`np.ndarray`): + The video to pad. + padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`): + Padding to apply to the edges of the height, width axes. Can be one of three formats: + - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis. + - `((before, after),)` yields same before and after pad for height and width. + - `(pad,)` or int is a shortcut for before = after = pad width for all axes. + mode (`PaddingMode`): + The padding mode to use. Can be one of: + - `"constant"`: pads with a constant value. + - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the + vector along each axis. + - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis. + - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output video. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format. + If unset, will use same as the input video. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input video. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format. + If unset, will use the inferred format of the input video. 
+ + Returns: + `np.ndarray`: The padded video. + + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(video) + + def _expand_for_data_format(values): + """ + Convert values to be in the format expected by np.pad based on the data format. + """ + if isinstance(values, (int, float)): + values = ((values, values), (values, values)) + elif isinstance(values, tuple) and len(values) == 1: + values = ((values[0], values[0]), (values[0], values[0])) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int): + values = (values, values) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple): + values = values + else: + raise ValueError(f"Unsupported format: {values}") + + # add 0 for channel dimension + values = ( + ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0)) + ) + + # Add additional padding if there's a batch dimension + values = (0, *values) if video.ndim == 5 else values + return values + + padding_map = { + PaddingMode.CONSTANT: "constant", + PaddingMode.REFLECT: "reflect", + PaddingMode.REPLICATE: "replicate", + PaddingMode.SYMMETRIC: "symmetric", + } + padding = _expand_for_data_format(padding) + + pad_kwargs = {} + if mode not in padding_map: + raise ValueError(f"Invalid padding mode: {mode}") + elif mode == PaddingMode.CONSTANT: + pad_kwargs["constant_values"] = _expand_for_data_format(constant_values) + + video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs) + video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video + return video + + +def group_videos_by_shape( + videos: list["mindspore.Tensor"], +) -> tuple[dict[tuple[int, int], list["mindspore.Tensor"]], dict[int, tuple[tuple[int, int], int]]]: + """ + Groups videos by shape. + Returns a dictionary with the shape as key and a list of videos with that shape as value, + and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value. + """ + grouped_videos = {} + grouped_videos_index = {} + for i, video in enumerate(videos): + shape = video.shape[-2::] + num_frames = video.shape[-4] # video format BTCHW + shape = (num_frames, *shape) + if shape not in grouped_videos: + grouped_videos[shape] = [] + grouped_videos[shape].append(video) + grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1) + # stack videos with the same size and number of frames + grouped_videos = {shape: mint.stack(videos, dim=0) for shape, videos in grouped_videos.items()} + return grouped_videos, grouped_videos_index + + +def reorder_videos( + processed_videos: dict[tuple[int, int], "mindspore.Tensor"], grouped_videos_index: dict[int, tuple[int, int]] +) -> list["mindspore.Tensor"]: + """ + Reconstructs a list of videos in the original order. 
+ """ + return [ + processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]] + for i in range(len(grouped_videos_index)) + ] From dcb98acc5a8959f39e3d27838b11a3f616463398 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:26:20 +0800 Subject: [PATCH 28/94] tokenization_utils.py update --- mindone/transformers/tokenization_utils.py | 125 ++++++++++++--------- 1 file changed, 74 insertions(+), 51 deletions(-) diff --git a/mindone/transformers/tokenization_utils.py b/mindone/transformers/tokenization_utils.py index e2673eed4d..2f4a2f8a37 100644 --- a/mindone/transformers/tokenization_utils.py +++ b/mindone/transformers/tokenization_utils.py @@ -24,7 +24,7 @@ import re import unicodedata from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union, overload +from typing import Any, Optional, Union, overload from .tokenization_utils_base import ( ENCODE_KWARGS_DOCSTRING, @@ -43,6 +43,7 @@ ) from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging + logger = logging.get_logger(__name__) # Slow tokenizers are saved in a vocabulary plus three separated files @@ -104,7 +105,7 @@ def add(self, word: str): ref = ref[char] ref[self._termination_char] = 1 - def split(self, text: str) -> List[str]: + def split(self, text: str) -> list[str]: """ Will look for the words added to the trie within `text`. Output is the original string splitted along the boundaries of the words found. @@ -317,6 +318,9 @@ def _get_node(self, token: str) -> dict: """ node = self.data for char in token: + if char not in node: + break + node = node[char] return node @@ -389,7 +393,7 @@ def _is_start_of_word(text): return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) -def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str): +def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str): """ Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted. """ @@ -423,11 +427,11 @@ def __init__(self, **kwargs): # 2. init `_added_tokens_decoder` if child class did not if not hasattr(self, "_added_tokens_decoder"): - self._added_tokens_decoder: Dict[int, AddedToken] = {} + self._added_tokens_decoder: dict[int, AddedToken] = {} # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {})) - self._added_tokens_encoder: Dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()} + self._added_tokens_encoder: dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()} # 4 init the parent class super().__init__(**kwargs) @@ -453,7 +457,7 @@ def vocab_size(self) -> int: raise NotImplementedError @property - def added_tokens_encoder(self) -> Dict[str, int]: + def added_tokens_encoder(self) -> dict[str, int]: """ Returns the sorted mapping from string to index. The added tokens encoder is cached for performance optimisation in `self._added_tokens_encoder` for the slow tokenizers. 
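As a reviewer note on the `Trie` class added above: a minimal sketch of the added-token splitting it provides, mirroring the behaviour documented upstream for `Trie.split` (outputs are illustrative):

```python
# Illustrative only: split raw text along added-token boundaries before tokenization.
trie = Trie()
print(trie.split("[CLS] This is a extra_id_100"))
# -> ["[CLS] This is a extra_id_100"]          (nothing added yet, no split happens)

trie.add("[CLS]")
trie.add("extra_id_100")
print(trie.split("[CLS] This is a extra_id_100"))
# -> ["[CLS]", " This is a ", "extra_id_100"]  (text is cut at the registered tokens)
```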
@@ -461,54 +465,61 @@ def added_tokens_encoder(self) -> Dict[str, int]: return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])} @property - def added_tokens_decoder(self) -> Dict[int, AddedToken]: + def added_tokens_decoder(self) -> dict[int, AddedToken]: """ Returns the added tokens in the vocabulary as a dictionary of index to AddedToken. Returns: - `Dict[str, int]`: The added tokens. + `dict[str, int]`: The added tokens. """ return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])) @added_tokens_decoder.setter - def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict[int, AddedToken]: + def added_tokens_decoder(self, value: dict[int, Union[AddedToken, str]]) -> dict[int, AddedToken]: # Always raise an error if string because users should define the behavior for index, token in value.items(): if not isinstance(token, (str, AddedToken)) or not isinstance(index, int): - raise ValueError( - f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, \ - should be a dict of {int, Union[AddedToken, str]}" + raise TypeError( + f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}" ) self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token self._added_tokens_encoder[str(token)] = index + self._update_total_vocab_size() - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: """ Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from the fast call because for now we always add the tokens even if they are already in the vocabulary. This is something we should change. Returns: - `Dict[str, int]`: The added tokens. + `dict[str, int]`: The added tokens. """ return self._added_tokens_encoder def __len__(self): """ - Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if - there is a hole in the vocab, we will add tokenizers at a wrong index. + Size of the full vocabulary with the added tokens. """ - return len(set(self.get_vocab().keys())) + return self.total_vocab_size - def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + def _update_total_vocab_size(self): + """ + Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because + otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and + is only updated when adding tokens. + """ + self.total_vocab_size = len(self.get_vocab()) + + def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int: """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the vocab which is why they have to be handled specifically. Args: - new_tokens (`List[str]`or `List[tokenizers.AddedToken]`): + new_tokens (`list[str]`or `list[tokenizers.AddedToken]`): Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary (tested by checking if the tokenizer assign the index of the `unk_token` to them). 
If a token is part of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the @@ -548,7 +559,9 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to else: # very important for fast and slow equivalence! is_special = token in self.all_special_tokens or special_tokens - token = AddedToken(token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special) + token = AddedToken( + token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special + ) elif special_tokens: # doing token.special=True changes the normalization! will fix in rust # this is important and the only reason why the AddedTokens in each class are normalized by default @@ -574,6 +587,7 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to logger.info(f"Adding {token} to the vocabulary") self._update_trie() + self._update_total_vocab_size() return added_tokens def _update_trie(self, unique_no_split_tokens: Optional[str] = []): @@ -607,7 +621,7 @@ def num_special_tokens_to_add(self, pair: bool = False) -> int: token_ids_1 = [] return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) - def tokenize(self, text: TextInput, **kwargs) -> List[str]: + def tokenize(self, text: TextInput, **kwargs) -> list[str]: """ Converts a string into a sequence of tokens, using the tokenizer. @@ -621,7 +635,7 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]: Passed along to the model-specific `prepare_for_tokenization` preprocessing method. Returns: - `List[str]`: The list of tokens. + `list[str]`: The list of tokens. """ split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens) @@ -696,16 +710,16 @@ def _tokenize(self, text, **kwargs): """ raise NotImplementedError - def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: + def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]: """ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the vocabulary. Args: - tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s). + tokens (`str` or `list[str]`): One or several token(s) to convert to token id(s). Returns: - `int` or `List[int]`: The token id or list of token ids. + `int` or `list[int]`: The token id or list of token ids. 
""" if tokens is None: return None @@ -740,6 +754,7 @@ def _encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -797,6 +812,7 @@ def get_input_ids(text): max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, prepend_batch_axis=True, return_attention_mask=return_attention_mask, @@ -810,12 +826,12 @@ def get_input_ids(text): def _batch_encode_plus( self, batch_text_or_text_pairs: Union[ - List[TextInput], - List[TextInputPair], - List[PreTokenizedInput], - List[PreTokenizedInputPair], - List[EncodedInput], - List[EncodedInputPair], + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + list[PreTokenizedInputPair], + list[EncodedInput], + list[EncodedInputPair], ], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, @@ -824,6 +840,7 @@ def _batch_encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -882,6 +899,7 @@ def get_input_ids(text): max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, @@ -897,13 +915,14 @@ def get_input_ids(text): @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def _batch_prepare_for_model( self, - batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], + batch_ids_pairs: list[Union[PreTokenizedInputPair, tuple[list[int], None]]], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[str] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -933,6 +952,7 @@ def _batch_prepare_for_model( max_length=max_length, stride=stride, pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward return_attention_mask=False, # we pad in batch afterward return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, @@ -954,6 +974,7 @@ def _batch_prepare_for_model( padding=padding_strategy.value, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) @@ -963,7 +984,7 @@ def _batch_prepare_for_model( def prepare_for_tokenization( self, text: str, is_split_into_words: bool = False, **kwargs - ) -> Tuple[str, Dict[str, Any]]: + ) -> tuple[str, dict[str, Any]]: """ Performs any necessary transformations before tokenization. @@ -977,25 +998,25 @@ def prepare_for_tokenization( Whether or not the input is already pre-tokenized (e.g., split into words). 
If set to `True`, the tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) which it will tokenize. This is useful for NER or token classification. - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): Keyword arguments to use for the tokenization. Returns: - `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs. + `tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs. """ return (text, kwargs) def get_special_tokens_mask( - self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False - ) -> List[int]: + self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False + ) -> list[int]: """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. Args: - token_ids_0 (`List[int]`): + token_ids_0 (`list[int]`): List of ids of the first sequence. - token_ids_1 (`List[int]`, *optional*): + token_ids_1 (`list[int]`, *optional*): List of ids of the second sequence. already_has_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not the token list is already formatted with special tokens for the model. @@ -1016,28 +1037,26 @@ def get_special_tokens_mask( return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) @overload - def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: - ... + def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ... @overload - def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: - ... + def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ... def convert_ids_to_tokens( - self, ids: Union[int, List[int]], skip_special_tokens: bool = False - ) -> Union[str, List[str]]: + self, ids: Union[int, list[int]], skip_special_tokens: bool = False + ) -> Union[str, list[str]]: """ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and added tokens. Args: - ids (`int` or `List[int]`): + ids (`int` or `list[int]`): The token id (or token ids) to convert to tokens. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. Returns: - `str` or `List[str]`: The decoded token(s). + `str` or `list[str]`: The decoded token(s). 
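`get_special_tokens_mask` and the `convert_ids_to_tokens` overloads documented above are easiest to see side by side; a sketch with the same hypothetical checkpoint:

from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)  # hypothetical checkpoint

ids = tok.encode("hello")                                                  # e.g. [CLS] hello [SEP]
print(tok.get_special_tokens_mask(ids, already_has_special_tokens=True))   # 1 marks special tokens
print(tok.convert_ids_to_tokens(ids))                                      # list input -> list[str]
print(tok.convert_ids_to_tokens(ids[0]))                                   # int input  -> single str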
""" if isinstance(ids, int): if ids in self._added_tokens_decoder: @@ -1058,20 +1077,24 @@ def convert_ids_to_tokens( def _convert_id_to_token(self, index: int) -> str: raise NotImplementedError - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: return " ".join(tokens) def _decode( self, - token_ids: List[int], + token_ids: Union[int, list[int]], skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool = None, + clean_up_tokenization_spaces: Optional[bool] = None, spaces_between_special_tokens: bool = True, **kwargs, ) -> str: self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + # If given is a single id, prevents splitting the string in upcoming loop + if isinstance(filtered_tokens, str): + filtered_tokens = [filtered_tokens] + legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | { token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size } @@ -1082,7 +1105,7 @@ def _decode( current_sub_text = [] # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string for token in filtered_tokens: - if skip_special_tokens and token in self.all_special_ids: + if skip_special_tokens and token in self.all_special_tokens: continue if token in legacy_added_tokens: if current_sub_text: From 512097722176afc0932a8c72e0c83a289bd77a89 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:34:25 +0800 Subject: [PATCH 29/94] add_model_info_to_custom_pipelines --- mindone/transformers/utils/generic.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index ed16b3095c..0597f24e0f 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -532,19 +532,6 @@ def tensor_size(array): raise ValueError(f"Type not supported for tensor_size: {type(array)}.") -# TODO: remove this function in v4.54.1 -def add_model_info_to_custom_pipelines(custom_pipeline, repo_id): - """ - Adds the information of the repo_id to a given custom pipeline. 
- """ - # {custom_pipelines : {task: {"impl": "path.to.task"},...} } - for task in custom_pipeline.keys(): - if "impl" in custom_pipeline[task]: - module = custom_pipeline[task]["impl"] - if "--" not in module: - custom_pipeline[task]["impl"] = f"{repo_id}--{module}" - return custom_pipeline - def infer_framework(model_class): """ From 9a1865591ac4359dff43d150d78ec92ea1665866 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:36:53 +0800 Subject: [PATCH 30/94] update tokenization_utils_base.py --- .../transformers/tokenization_utils_base.py | 1385 ++++++++--------- 1 file changed, 663 insertions(+), 722 deletions(-) diff --git a/mindone/transformers/tokenization_utils_base.py b/mindone/transformers/tokenization_utils_base.py index bfd2142754..78788edf6b 100644 --- a/mindone/transformers/tokenization_utils_base.py +++ b/mindone/transformers/tokenization_utils_base.py @@ -27,36 +27,40 @@ import re import warnings from collections import UserDict -from collections.abc import Mapping, Sized +from collections.abc import Mapping, Sequence, Sized from contextlib import contextmanager from dataclasses import dataclass -from functools import lru_cache -from inspect import isfunction -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union import numpy as np from packaging import version from . import __version__ -from .dynamic_module_utils import custom_object_save +from transformers.dynamic_module_utils import custom_object_save from .utils import ( ExplicitEnum, PaddingStrategy, - PushToHubMixin, TensorType, + is_numpy_array, + requires_backends, + to_py_obj, +) + +from transformers.utils import ( + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, + PushToHubMixin, add_end_docstrings, - add_model_info_to_auto_map, - add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, extract_commit_hash, - get_json_schema, is_flax_available, is_jax_tensor, is_mlx_available, - is_numpy_array, is_offline_mode, + is_protobuf_available, is_remote_url, is_tf_available, is_tf_tensor, @@ -64,10 +68,13 @@ is_torch_available, is_torch_device, is_torch_tensor, + list_repo_templates, logging, - requires_backends, - to_py_obj, + ) +from transformers.utils.chat_template_utils import render_jinja_template +from transformers.utils.import_utils import PROTOBUF_IMPORT_ERROR + if TYPE_CHECKING: if is_torch_available(): @@ -77,6 +84,16 @@ if is_flax_available(): import jax.numpy as jnp # noqa: F401 + +def import_protobuf_decode_error(error_message=""): + if is_protobuf_available(): + from google.protobuf.message import DecodeError + + return DecodeError + else: + raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message)) + + if is_tokenizers_available(): from tokenizers import AddedToken from tokenizers import Encoding as EncodingFast @@ -92,7 +109,9 @@ class AddedToken: `tokenizers`. 
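The `AddedToken` wrapper whose constructor is reformatted just below controls how an added token is matched; a sketch of its flags, not taken from the patch itself:

from tokenizers import AddedToken
from transformers import AutoTokenizer  # this file mirrors the upstream API
tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)  # hypothetical checkpoint

# lstrip/rstrip absorb neighbouring whitespace, normalized=False matches the raw text,
# special=True lets skip_special_tokens drop the token again at decode time.
marker = AddedToken("<sep2>", lstrip=True, rstrip=True, normalized=False, special=True)
tok.add_tokens([marker], special_tokens=True)
print(tok.tokenize("foo <sep2> bar"))   # '<sep2>' is kept as one piece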
""" - def __init__(self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None): + def __init__( + self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None + ): self.content = content self.single_word = single_word self.lstrip = lstrip @@ -120,14 +139,14 @@ class EncodingFast: # Define type aliases and NamedTuples TextInput = str -PreTokenizedInput = List[str] -EncodedInput = List[int] -TextInputPair = Tuple[str, str] -PreTokenizedInputPair = Tuple[List[str], List[str]] -EncodedInputPair = Tuple[List[int], List[int]] +PreTokenizedInput = list[str] +EncodedInput = list[int] +TextInputPair = tuple[str, str] +PreTokenizedInputPair = tuple[list[str], list[str]] +EncodedInputPair = tuple[list[int], list[int]] # Define type aliases for text-related non-text modalities -AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch.Tensor"]] +AudioInput = Union["np.ndarray", "torch.Tensor", list["np.ndarray"], list["torch.Tensor"]] # Slow tokenizers used to be saved in three separated files SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" @@ -198,7 +217,8 @@ class BatchEncoding(UserDict): You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at initialization. prepend_batch_axis (`bool`, *optional*, defaults to `False`): - Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). + Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this + parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*. n_sequences (`Optional[int]`, *optional*): You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at initialization. @@ -206,7 +226,7 @@ class BatchEncoding(UserDict): def __init__( self, - data: Optional[Dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, tensor_type: Union[None, str, TensorType] = None, prepend_batch_axis: bool = False, @@ -219,7 +239,7 @@ def __init__( self._encodings = encoding - if n_sequences is None and encoding is not None and len(encoding): + if n_sequences is None and encoding is not None and encoding: n_sequences = encoding[0].n_sequences self._n_sequences = n_sequences @@ -281,28 +301,19 @@ def __setstate__(self, state): if "encodings" in state: self._encodings = state["encodings"] - def keys(self): - return self.data.keys() - - def values(self): - return self.data.values() - - def items(self): - return self.data.items() - # After this point: # Extended properties and methods only available for fast (Rust-based) tokenizers # provided by HuggingFace tokenizers library. @property - def encodings(self) -> Optional[List[EncodingFast]]: + def encodings(self) -> Optional[list[EncodingFast]]: """ - `Optional[List[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns `None` if + `Optional[list[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns `None` if the input was tokenized through Python (i.e., not a fast) tokenizer. 
""" return self._encodings - def tokens(self, batch_index: int = 0) -> List[str]: + def tokens(self, batch_index: int = 0) -> list[str]: """ Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to integer indices) at a given batch index (only works for the output of a fast tokenizer). @@ -311,7 +322,7 @@ def tokens(self, batch_index: int = 0) -> List[str]: batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. Returns: - `List[str]`: The list of tokens at that index. + `list[str]`: The list of tokens at that index. """ if not self._encodings: raise ValueError( @@ -320,7 +331,7 @@ def tokens(self, batch_index: int = 0) -> List[str]: ) return self._encodings[batch_index].tokens - def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]: + def sequence_ids(self, batch_index: int = 0) -> list[Optional[int]]: """ Return a list mapping the tokens to the id of their original sentences: @@ -333,7 +344,7 @@ def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]: batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. Returns: - `List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added + `list[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding sequence. """ @@ -344,7 +355,7 @@ def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]: ) return self._encodings[batch_index].sequence_ids - def words(self, batch_index: int = 0) -> List[Optional[int]]: + def words(self, batch_index: int = 0) -> list[Optional[int]]: """ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. @@ -352,7 +363,7 @@ def words(self, batch_index: int = 0) -> List[Optional[int]]: batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. Returns: - `List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the + `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word (several tokens will be mapped to the same word index if they are parts of that word). """ @@ -368,7 +379,7 @@ def words(self, batch_index: int = 0) -> List[Optional[int]]: ) return self.word_ids(batch_index) - def word_ids(self, batch_index: int = 0) -> List[Optional[int]]: + def word_ids(self, batch_index: int = 0) -> list[Optional[int]]: """ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. @@ -376,7 +387,7 @@ def word_ids(self, batch_index: int = 0) -> List[Optional[int]]: batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. Returns: - `List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the + `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word (several tokens will be mapped to the same word index if they are parts of that word). 
""" @@ -517,7 +528,7 @@ def word_to_tokens( span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index) return TokenSpan(*span) if span is not None else None - def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: + def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> Optional[CharSpan]: """ Get the character span corresponding to an encoded token in a sequence of the batch. @@ -556,7 +567,9 @@ def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = return CharSpan(*span_indices) if span_indices is not None else None - def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: + def char_to_token( + self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0 + ) -> int: """ Get the index of the token in the encoded output comprising a character in the original string for a sequence of the batch. @@ -583,7 +596,8 @@ def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = No Returns: - `int`: Index of the token. + `int`: Index of the token, or None if the char index refers to a whitespace only token and whitespace is + trimmed with `trim_offsets=True`. """ if not self._encodings: @@ -623,7 +637,7 @@ def word_to_chars( or 1) the provided word index belongs to. Returns: - `CharSpan` or `List[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan + `CharSpan` or `list[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan are NamedTuple with: - start: index of the first character associated to the token in the original string @@ -667,7 +681,7 @@ def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = Non Returns: - `int` or `List[int]`: Index or indices of the associated encoded token(s). + `int` or `list[int]`: Index or indices of the associated encoded token(s). """ if not self._encodings: @@ -702,7 +716,9 @@ def convert_to_tensors( # Get a function reference for the correct framework if tensor_type == TensorType.TENSORFLOW: if not is_tf_available(): - raise ImportError("Unable to convert output to TensorFlow tensors format, TensorFlow is not installed.") + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) import tensorflow as tf as_tensor = tf.constant @@ -716,7 +732,7 @@ def convert_to_tensors( def as_tensor(value, dtype=None): if isinstance(value, list) and isinstance(value[0], np.ndarray): - return torch.tensor(np.array(value)) + return torch.from_numpy(np.array(value)) return torch.tensor(value) elif tensor_type == TensorType.JAX: @@ -736,7 +752,6 @@ def as_tensor(value, dtype=None): def is_tensor(obj): return isinstance(obj, mx.array) - else: def as_tensor(value, dtype=None): @@ -781,12 +796,13 @@ def as_tensor(value, dtype=None): return self - def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": + def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding": """ - Send all values to device by calling `v.to(device)` (PyTorch only). + Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only). Args: device (`str` or `torch.device`): The device to put the tensors on. + non_blocking (`bool`): Whether to perform the copy asynchronously. Returns: [`BatchEncoding`]: The same instance after modification. 
@@ -797,7 +813,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) for k, v in self.data.items()} + self.data = { + k: v.to(device=device, non_blocking=non_blocking) if hasattr(v, "to") and callable(v.to) else v + for k, v in self.data.items() + } else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") return self @@ -843,16 +862,10 @@ class SpecialTokensMixin: ] def __init__(self, verbose=False, **kwargs): - self._bos_token = None - self._eos_token = None - self._unk_token = None - self._sep_token = None - self._pad_token = None - self._cls_token = None - self._mask_token = None self._pad_token_type_id = 0 - self._additional_special_tokens = [] self.verbose = verbose + self._special_tokens_map = dict.fromkeys(self.SPECIAL_TOKENS_ATTRIBUTES) + self._special_tokens_map["additional_special_tokens"] = [] # for BC where it defaults to empty list # We directly set the hidden value to allow initialization with special tokens # which are not yet in the vocabulary. Necessary for serialization/de-serialization @@ -864,9 +877,9 @@ def __init__(self, verbose=False, **kwargs): if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == "additional_special_tokens": assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" - assert all( - isinstance(t, (str, AddedToken)) for t in value - ), "One of the tokens is not a string or an AddedToken" + assert all(isinstance(t, (str, AddedToken)) for t in value), ( + "One of the tokens is not a string or an AddedToken" + ) setattr(self, key, value) elif isinstance(value, (str, AddedToken)): setattr(self, key, value) @@ -882,7 +895,9 @@ def sanitize_special_tokens(self) -> int: return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) def add_special_tokens( - self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True + self, + special_tokens_dict: dict[str, Union[str, AddedToken, Sequence[Union[str, AddedToken]]]], + replace_additional_special_tokens=True, ) -> int: """ Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If @@ -902,11 +917,11 @@ def add_special_tokens( makes it easy to develop model-agnostic training and fine-tuning scripts. When possible, special tokens are already registered for provided pretrained models (for instance - [`BertTokenizer`] `cls_token` is already registered to be :obj*'[CLS]'* and XLM's one is also registered to be + [`BertTokenizer`] `cls_token` is already registered to be `'[CLS]'` and XLM's one is also registered to be `''`). Args: - special_tokens_dict (dictionary *str* to *str* or `tokenizers.AddedToken`): + special_tokens_dict (dictionary *str* to *str*, `tokenizers.AddedToken`, or `Sequence[Union[str, AddedToken]]`): Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`]. @@ -914,7 +929,7 @@ def add_special_tokens( assign the index of the `unk_token` to them). replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`): If `True`, the existing list of additional special tokens will be replaced by the list provided in - `special_tokens_dict`. 
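The `.to()` hunk above adds a `non_blocking` flag and now skips values that do not expose a `.to()` method instead of failing on them; the docstring still marks this path as PyTorch-only, so the sketch assumes `torch` and an available CUDA device:

import torch
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)  # hypothetical checkpoint

enc = tok("hello", return_tensors="pt")
enc["word_list"] = ["hello"]                 # non-tensor entries are left untouched by .to()
if torch.cuda.is_available():
    enc = enc.to("cuda", non_blocking=True)  # asynchronous host-to-device copy
print(enc["input_ids"].device)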
Otherwise, `self._additional_special_tokens` is just extended. In the former + `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous @@ -950,9 +965,9 @@ def add_special_tokens( logger.info(f"Assigning {value} to the {key} key of the tokenizer") if key == "additional_special_tokens": - assert isinstance(value, (list, tuple)) and all( - isinstance(t, (str, AddedToken)) for t in value - ), f"Tokens {value} for key {key} should all be str or AddedToken instances" + assert isinstance(value, (list, tuple)) and all(isinstance(t, (str, AddedToken)) for t in value), ( + f"Tokens {value} for key {key} should all be str or AddedToken instances" + ) to_add = [] for token in value: @@ -965,7 +980,7 @@ def add_special_tokens( if replace_additional_special_tokens and len(to_add) > 0: setattr(self, key, list(to_add)) else: - self._additional_special_tokens.extend(to_add) + self._special_tokens_map["additional_special_tokens"].extend(to_add) added_tokens += to_add else: @@ -984,11 +999,11 @@ def add_special_tokens( return added_tokens def add_tokens( - self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False + self, new_tokens: Union[str, AddedToken, Sequence[Union[str, AddedToken]]], special_tokens: bool = False ) -> int: """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to - it with indices starting from length of the current vocabulary and and will be isolated before the tokenization + it with indices starting from length of the current vocabulary and will be isolated before the tokenization algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore not treated in the same way. @@ -998,7 +1013,7 @@ def add_tokens( In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. Args: - new_tokens (`str`, `tokenizers.AddedToken` or a list of *str* or `tokenizers.AddedToken`): + new_tokens (`str`, `tokenizers.AddedToken` or a sequence of *str* or `tokenizers.AddedToken`): Tokens are only added if they are not already in the vocabulary. `tokenizers.AddedToken` wraps a string token to let you personalize its behavior: whether this token should only match against a single word, whether this token should strip all potential whitespaces on the left side, whether this token should @@ -1032,195 +1047,9 @@ def add_tokens( return self._add_tokens(new_tokens, special_tokens=special_tokens) - def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int: raise NotImplementedError - @property - def bos_token(self) -> str: - """ - `str`: Beginning of sentence token. Log an error if used while not having been set. - """ - if self._bos_token is None: - if self.verbose: - logger.error("Using bos_token, but it is not set yet.") - return None - return str(self._bos_token) - - @property - def eos_token(self) -> str: - """ - `str`: End of sentence token. Log an error if used while not having been set. 
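The `replace_additional_special_tokens` behaviour described above is easiest to see with two successive calls; a sketch:

from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)  # hypothetical checkpoint

tok.add_special_tokens({"additional_special_tokens": ["<ctx>"]})
tok.add_special_tokens({"additional_special_tokens": ["<ans>"]}, replace_additional_special_tokens=False)
print(tok.additional_special_tokens)       # ['<ctx>', '<ans>'] - extended, not replaced
print(tok.additional_special_tokens_ids)   # ids resolved through convert_tokens_to_ids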
- """ - if self._eos_token is None: - if self.verbose: - logger.error("Using eos_token, but it is not set yet.") - return None - return str(self._eos_token) - - @property - def unk_token(self) -> str: - """ - `str`: Unknown token. Log an error if used while not having been set. - """ - if self._unk_token is None: - if self.verbose: - logger.error("Using unk_token, but it is not set yet.") - return None - return str(self._unk_token) - - @property - def sep_token(self) -> str: - """ - `str`: Separation token, to separate context and query in an input sequence. Log an error if used while not - having been set. - """ - if self._sep_token is None: - if self.verbose: - logger.error("Using sep_token, but it is not set yet.") - return None - return str(self._sep_token) - - @property - def pad_token(self) -> str: - """ - `str`: Padding token. Log an error if used while not having been set. - """ - if self._pad_token is None: - if self.verbose: - logger.error("Using pad_token, but it is not set yet.") - return None - return str(self._pad_token) - - @property - def cls_token(self) -> str: - """ - `str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the full - depth of the model. Log an error if used while not having been set. - """ - if self._cls_token is None: - if self.verbose: - logger.error("Using cls_token, but it is not set yet.") - return None - return str(self._cls_token) - - @property - def mask_token(self) -> str: - """ - `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not - having been set. - """ - if self._mask_token is None: - if self.verbose: - logger.error("Using mask_token, but it is not set yet.") - return None - return str(self._mask_token) - - @property - def additional_special_tokens(self) -> List[str]: - """ - `List[str]`: All the additional special tokens you may want to use. Log an error if used while not having been - set. 
- """ - if self._additional_special_tokens is None: - if self.verbose: - logger.error("Using additional_special_tokens, but it is not set yet.") - return None - return [str(tok) for tok in self._additional_special_tokens] - - @bos_token.setter - def bos_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the BOS token") - self._bos_token = value - - @eos_token.setter - def eos_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the EOS token") - self._eos_token = value - - @unk_token.setter - def unk_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the UNK token") - self._unk_token = value - - @sep_token.setter - def sep_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the SEP token") - self._sep_token = value - - @pad_token.setter - def pad_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the PAD token") - self._pad_token = value - - @cls_token.setter - def cls_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the CLS token") - self._cls_token = value - - @mask_token.setter - def mask_token(self, value): - if not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError("Cannot set a non-string value as the MASK token") - self._mask_token = value - - @additional_special_tokens.setter - def additional_special_tokens(self, value): - self._additional_special_tokens = value if value is not None else None - - @property - def bos_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the beginning of sentence token in the vocabulary. Returns `None` if the token has not - been set. - """ - if self._bos_token is None: - return None - return self.convert_tokens_to_ids(self.bos_token) - - @property - def eos_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been - set. - """ - if self._eos_token is None: - return None - return self.convert_tokens_to_ids(self.eos_token) - - @property - def unk_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the unknown token in the vocabulary. Returns `None` if the token has not been set. - """ - if self._unk_token is None: - return None - return self.convert_tokens_to_ids(self.unk_token) - - @property - def sep_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the separation token in the vocabulary, to separate context and query in an input - sequence. Returns `None` if the token has not been set. - """ - if self._sep_token is None: - return None - return self.convert_tokens_to_ids(self.sep_token) - - @property - def pad_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set. 
- """ - if self._pad_token is None: - return None - return self.convert_tokens_to_ids(self.pad_token) - @property def pad_token_type_id(self) -> int: """ @@ -1228,72 +1057,60 @@ def pad_token_type_id(self) -> int: """ return self._pad_token_type_id - @property - def cls_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the classification token in the vocabulary, to extract a summary of an input sequence - leveraging self-attention along the full depth of the model. - - Returns `None` if the token has not been set. - """ - if self._cls_token is None: - return None - return self.convert_tokens_to_ids(self.cls_token) - - @property - def mask_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the mask token in the vocabulary, used when training a model with masked-language - modeling. Returns `None` if the token has not been set. - """ - if self._mask_token is None: - return None - return self.convert_tokens_to_ids(self.mask_token) - - @property - def additional_special_tokens_ids(self) -> List[int]: - """ - `List[int]`: Ids of all the additional special tokens in the vocabulary. Log an error if used while not having - been set. - """ - return self.convert_tokens_to_ids(self.additional_special_tokens) - - @bos_token_id.setter - def bos_token_id(self, value): - self._bos_token = self.convert_ids_to_tokens(value) if value is not None else None - - @eos_token_id.setter - def eos_token_id(self, value): - self._eos_token = self.convert_ids_to_tokens(value) if value is not None else None + def __setattr__(self, key, value): + key_without_id = key + key_is_special_id = key.endswith("_id") or key.endswith("_ids") + if key_is_special_id: + key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] - @unk_token_id.setter - def unk_token_id(self, value): - self._unk_token = self.convert_ids_to_tokens(value) if value is not None else None - - @sep_token_id.setter - def sep_token_id(self, value): - self._sep_token = self.convert_ids_to_tokens(value) if value is not None else None + if self.__dict__.get("_special_tokens_map", None) is not None and any( + name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] + ): + if key_is_special_id: + if value is not None: + value = ( + self.convert_ids_to_tokens(value) + if key != "additional_special_tokens" + else [self.convert_ids_to_tokens(val) for val in value] + ) + key = key_without_id - @pad_token_id.setter - def pad_token_id(self, value): - self._pad_token = self.convert_ids_to_tokens(value) if value is not None else None + if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None: + raise ValueError(f"Cannot set a non-string value as the {key}") + self._special_tokens_map[key] = value + else: + super().__setattr__(key, value) - @cls_token_id.setter - def cls_token_id(self, value): - self._cls_token = self.convert_ids_to_tokens(value) if value is not None else None + def __getattr__(self, key): + key_without_id = key + key_is_special_id = key.endswith("_id") or key.endswith("_ids") + if key_is_special_id: + key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] - @mask_token_id.setter - def mask_token_id(self, value): - self._mask_token = self.convert_ids_to_tokens(value) if value is not None else None + if self.__dict__.get("_special_tokens_map", None) is not None and any( + name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] + ): + _special_tokens_map = self.__dict__["_special_tokens_map"] + if not 
key_is_special_id: + if _special_tokens_map[key] is None: + if self.verbose: + logger.error(f"Using {key}, but it is not set yet.") + return None + value = _special_tokens_map[key] + return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value] + else: + attr_as_tokens = getattr(self, key_without_id) + return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None - @additional_special_tokens_ids.setter - def additional_special_tokens_ids(self, values): - self._additional_special_tokens = [self.convert_ids_to_tokens(value) for value in values] + if key not in self.__dict__: + raise AttributeError(f"{self.__class__.__name__} has no attribute {key}") + else: + return super().__getattr__(key) @property - def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]: + def special_tokens_map(self) -> dict[str, Union[str, list[str]]]: """ - `Dict[str, Union[str, List[str]]]`: A dictionary mapping special token class attributes (`cls_token`, + `dict[str, Union[str, list[str]]]`: A dictionary mapping special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`''`, `''`, etc.). Convert potential tokens of `tokenizers.AddedToken` type to string. @@ -1306,9 +1123,9 @@ def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]: return set_attr @property - def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]: + def special_tokens_map_extended(self) -> dict[str, Union[str, AddedToken, list[Union[str, AddedToken]]]]: """ - `Dict[str, Union[str, tokenizers.AddedToken, List[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping + `dict[str, Union[str, tokenizers.AddedToken, list[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`''`, `''`, etc.). Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how @@ -1316,15 +1133,15 @@ def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[U """ set_attr = {} for attr in self.SPECIAL_TOKENS_ATTRIBUTES: - attr_value = getattr(self, "_" + attr) + attr_value = self._special_tokens_map[attr] if attr_value: set_attr[attr] = attr_value return set_attr @property - def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]: + def all_special_tokens_extended(self) -> list[Union[str, AddedToken]]: """ - `List[Union[str, tokenizers.AddedToken]]`: All the special tokens (`''`, `''`, etc.), the order has + `list[Union[str, tokenizers.AddedToken]]`: All the special tokens (`''`, `''`, etc.), the order has nothing to do with the index of each tokens. If you want to know the correct indices, check `self.added_tokens_encoder`. We can't create an order anymore as the keys are `AddedTokens` and not `Strings`. @@ -1343,9 +1160,9 @@ def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]: return all_tokens @property - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: """ - `List[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). + `list[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). Convert tokens of `tokenizers.AddedToken` type to string. 
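The `__setattr__`/`__getattr__` pair above replaces the long list of per-token properties removed by this hunk: every name in `SPECIAL_TOKENS_ATTRIBUTES`, plus its `_id`/`_ids` twin, is now routed through the single `_special_tokens_map`. A behavioural sketch, using a hypothetical checkpoint that ships without a pad token:

from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)  # hypothetical checkpoint, no pad token

tok.pad_token = tok.eos_token             # __setattr__ stores the string in _special_tokens_map
print(tok.pad_token, tok.pad_token_id)    # __getattr__ serves both the string and its id
tok.pad_token_id = tok.eos_token_id       # *_id assignment converts the id back to a token string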
""" @@ -1353,26 +1170,40 @@ def all_special_tokens(self) -> List[str]: return all_toks @property - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: """ - `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. + `list[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. """ all_toks = self.all_special_tokens all_ids = self.convert_tokens_to_ids(all_toks) return all_ids + def _set_model_specific_special_tokens(self, special_tokens: list[str]): + """ + Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part + of "self.special_tokens" and saved as a special token in tokenizer's config. + This allows us to dynamically add new model-type specific tokens after initializing the tokenizer. + For example: if the model tokenizers is multimodal, we can support special image or audio tokens. + """ + self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys()) + for key, value in special_tokens.items(): + if isinstance(value, (str, AddedToken)): + self._special_tokens_map[key] = value + else: + raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") + ENCODE_KWARGS_DOCSTRING = r""" add_special_tokens (`bool`, *optional*, defaults to `True`): Whether or not to add special tokens when encoding the sequences. This will use the underlying `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are - automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens + automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens automatically. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): Activates and controls padding. Accepts the following values: - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). + sequence is provided). - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum acceptable input length for the model if that argument is not provided. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different @@ -1411,6 +1242,9 @@ def all_special_ids(self) -> List[int]: If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors instead of list of python integers. Acceptable values are: @@ -1477,14 +1311,14 @@ def all_special_ids(self) -> List[int]: INIT_TOKENIZER_DOCSTRING = r""" Class attributes (overridden by derived classes) - - **vocab_files_names** (`Dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each + - **vocab_files_names** (`dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). 
- - **pretrained_vocab_files_map** (`Dict[str, Dict[str, str]]`) -- A dictionary of dictionaries, with the + - **pretrained_vocab_files_map** (`dict[str, dict[str, str]]`) -- A dictionary of dictionaries, with the high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the associated pretrained vocabulary file. - - **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model. + - **model_input_names** (`list[str]`) -- A list of inputs expected in the forward pass of the model. - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied. Should be `'right'` or `'left'`. - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation @@ -1505,7 +1339,7 @@ def all_special_ids(self) -> List[int]: chat_template (`str`, *optional*): A Jinja template string that will be used to format lists of chat messages. See https://huggingface.co/docs/transformers/chat_templating for a full description. - model_input_names (`List[string]`, *optional*): + model_input_names (`list[string]`, *optional*): The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or `"attention_mask"`). Default value is picked from the class attribute of the same name. bos_token (`str` or `tokenizers.AddedToken`, *optional*): @@ -1552,13 +1386,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): Handles shared (mostly boiler plate) methods for those two classes. """ - vocab_files_names: Dict[str, str] = {} - pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {} + vocab_files_names: dict[str, str] = {} + pretrained_vocab_files_map: dict[str, dict[str, str]] = {} _auto_class: Optional[str] = None # first name has to correspond to main model input name # to make sure `tokenizer.pad(...)` works correctly - model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"] + model_input_names: list[str] = ["input_ids", "token_type_ids", "attention_mask"] padding_side: str = "right" truncation_side: str = "right" slow_tokenizer_class = None @@ -1566,6 +1400,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () + for key in kwargs: + if hasattr(self, key) and callable(getattr(self, key)): + raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") + self.init_kwargs = copy.deepcopy(kwargs) self.name_or_path = kwargs.pop("name_or_path", "") self._processor_class = kwargs.pop("processor_class", None) @@ -1591,14 +1429,12 @@ def __init__(self, **kwargs): self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) # By default, cleaning tokenization spaces for both fast and slow tokenizers - self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True) + self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False) # By default, do not split special tokens for both fast and slow tokenizers self.split_special_tokens = kwargs.pop("split_special_tokens", False) - self.deprecation_warnings = ( - {} - ) # Use to store when we have already noticed a deprecation warning (avoid overlogging). 
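Two behavioural changes sit in the `__init__` above: kwargs that collide with a method name now raise an `AttributeError`, and `clean_up_tokenization_spaces` defaults to `False`, so decoded text keeps spaces before punctuation unless the old behaviour is requested. A sketch, assuming the checkpoint config does not pin the old default:

from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)  # hypothetical checkpoint

ids = tok.encode("Hello , world !", add_special_tokens=False)
print(tok.decode(ids))                                     # spaces before punctuation are kept
print(tok.decode(ids, clean_up_tokenization_spaces=True))  # opt back in to the old clean-up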
+ self.deprecation_warnings = {} # Use to store when we have already noticed a deprecation warning (avoid overlogging). self._in_target_context_manager = False # Stores a Jinja template that formats chat histories into tokenizable strings @@ -1610,6 +1446,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) + self.extra_special_tokens = kwargs.pop("extra_special_tokens", {}) + self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens) + @property def max_len_single_sentence(self) -> int: """ @@ -1634,7 +1473,9 @@ def max_len_single_sentence(self, value) -> int: ) self.deprecation_warnings["max_len_single_sentence"] = True else: - raise ValueError("Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up.") + raise ValueError( + "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up." + ) @max_len_sentences_pair.setter def max_len_sentences_pair(self, value) -> int: @@ -1653,7 +1494,7 @@ def _set_processor_class(self, processor_class: str): self._processor_class = processor_class @property - def added_tokens_decoder(self) -> Dict[int, AddedToken]: + def added_tokens_decoder(self) -> dict[int, AddedToken]: raise NotImplementedError() def __repr__(self) -> str: @@ -1662,14 +1503,14 @@ def __repr__(self) -> str: f"{self.__class__.__name__}(name_or_path='{self.name_or_path}'," f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast}," f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}'," - f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}), " - " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}" + f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}," + " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)" ) def __len__(self) -> int: raise NotImplementedError() - def get_vocab(self) -> Dict[str, int]: + def get_vocab(self) -> dict[str, int]: """ Returns the vocabulary as a dictionary of token to index. @@ -1677,42 +1518,43 @@ def get_vocab(self) -> Dict[str, int]: vocab. Returns: - `Dict[str, int]`: The vocabulary. + `dict[str, int]`: The vocabulary. """ raise NotImplementedError() def apply_chat_template( self, - conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], - tools: Optional[List[Dict]] = None, - documents: Optional[List[Dict[str, str]]] = None, + conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], + tools: Optional[list[Union[dict, Callable]]] = None, + documents: Optional[list[dict[str, str]]] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = False, + continue_final_message: bool = False, tokenize: bool = True, - padding: bool = False, + padding: Union[bool, str, PaddingStrategy] = False, truncation: bool = False, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_dict: bool = False, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, + return_assistant_tokens_mask: bool = False, + tokenizer_kwargs: Optional[dict[str, Any]] = None, **kwargs, - ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]: + ) -> Union[str, list[int], list[str], list[list[int]], BatchEncoding]: """ Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids. 
This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to - determine the format and control tokens to use when converting. When chat_template is None, it will fall back - to the default_chat_template specified at the class level. + determine the format and control tokens to use when converting. Args: - conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts + conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts with "role" and "content" keys, representing the chat history so far. - tools (`List[Dict]`, *optional*): + tools (`list[Union[Dict, Callable]]`, *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, giving the name, description and argument types for the tool. See our [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) for more information. - documents (`List[Dict[str, str]]`, *optional*): + documents (`list[dict[str, str]]`, *optional*): A list of dicts representing documents that will be accessible to the model if it is performing RAG (retrieval-augmented generation). If the template does not support RAG, this argument will have no effect. We recommend that each document should be a dict containing "title" and "text" keys. Please @@ -1721,14 +1563,28 @@ def apply_chat_template( chat_template (`str`, *optional*): A Jinja template to use for this conversion. It is usually not necessary to pass anything to this argument, as the model's template will be used by default. - add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate - the start of an assistant message. This is useful when you want to generate a response from the model. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. tokenize (`bool`, defaults to `True`): Whether to tokenize the output. If `False`, the output will be a string. - padding (`bool`, defaults to `False`): - Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. 
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). truncation (`bool`, defaults to `False`): Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. max_length (`int`, *optional*): @@ -1743,11 +1599,15 @@ def apply_chat_template( - `'jax'`: Return JAX `jnp.ndarray` objects. return_dict (`bool`, defaults to `False`): Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. - tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + tokenizer_kwargs (`dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + return_assistant_tokens_mask (`bool`, defaults to `False`): + Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, + the mask will contain 1. For user and system tokens, the mask will contain 0. + This functionality is only available for chat templates that support it via the `{% generation %}` keyword. **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. Returns: - `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This + `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is set, will return a dict of tokenizer outputs instead. """ @@ -1758,60 +1618,13 @@ def apply_chat_template( "of tokenizer outputs to return." ) + if return_assistant_tokens_mask and not return_dict: + raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`") + if tokenizer_kwargs is None: tokenizer_kwargs = {} - using_default_template = False - - # First, handle the cases when the model has a dict of multiple templates - if isinstance(self.chat_template, dict) or ( - self.chat_template is None and isinstance(self.default_chat_template, dict) - ): - if self.chat_template is not None: - template_dict = self.chat_template - using_default_dict = False - else: - template_dict = self.default_chat_template - using_default_dict = True - if chat_template is not None and chat_template in template_dict: - # The user can pass the name of a template to the chat template argument instead of an entire template - chat_template = template_dict[chat_template] - if using_default_dict: - using_default_template = True - elif chat_template is None: - if tools is not None and "tool_use" in template_dict: - chat_template = template_dict["tool_use"] - elif "default" in template_dict: - chat_template = template_dict["default"] - else: - raise ValueError( - "This model has multiple chat templates with no default specified! Please either pass a chat " - "template or the name of the template you wish to use to the `chat_template` argument. Available " - f"template names are {sorted(template_dict.keys())}." 
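The arguments documented above are easiest to see end to end; a sketch assuming a chat-capable checkpoint (the model id is only an example):

from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # hypothetical chat checkpoint

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is MindSpore?"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)                                                        # formatted string ending with the assistant header
ids = tok.apply_chat_template(messages, add_generation_prompt=True)  # tokenize=True -> list[int]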
- ) - if using_default_dict: - using_default_template = True - - elif chat_template is None: - # These are the cases when the model has a single template - # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template - if self.chat_template is not None: - chat_template = self.chat_template - else: - chat_template = self.default_chat_template - using_default_template = True - - if using_default_template: - logger.warning_once( - "No chat template is set for this tokenizer, falling back to a default class-level template. This is " - "very error-prone, because models are often trained with templates different from the class default! " - "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " - "point any code depending on them will stop working. We recommend setting a valid chat template before " - "then to ensure that this model continues working without issues." - ) - - # Compilation function uses a cache to avoid recompiling the same template - compiled_template = self._compile_jinja_template(chat_template) + chat_template = self.get_chat_template(chat_template, tools) if isinstance(conversation, (list, tuple)) and ( isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") @@ -1822,48 +1635,32 @@ def apply_chat_template( conversations = [conversation] is_batched = False - # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas - if tools is not None: - tool_schemas = [] - for tool in tools: - if isinstance(tool, dict): - tool_schemas.append(tool) - elif isfunction(tool): - tool_schemas.append(get_json_schema(tool)) - else: - raise ValueError( - "Tools should either be a JSON schema, or a callable function with type hints " - "and a docstring suitable for auto-conversion to a schema." - ) - else: - tool_schemas = None - - if documents is not None: - for document in documents: - if not isinstance(document, dict): - raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!") + if continue_final_message: + if add_generation_prompt: + raise ValueError( + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." 
+ ) + if return_assistant_tokens_mask: + raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") - rendered = [] template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present - for chat in conversations: - if hasattr(chat, "messages"): - # Indicates it's a Conversation object - chat = chat.messages - rendered_chat = compiled_template.render( - messages=chat, - tools=tool_schemas, - documents=documents, - add_generation_prompt=add_generation_prompt, - **template_kwargs, - ) - rendered.append(rendered_chat) + rendered_chat, generation_indices = render_jinja_template( + conversations=conversations, + tools=tools, + documents=documents, + chat_template=chat_template, + return_assistant_tokens_mask=return_assistant_tokens_mask, + continue_final_message=continue_final_message, + add_generation_prompt=add_generation_prompt, + **template_kwargs, + ) if not is_batched: - rendered = rendered[0] + rendered_chat = rendered_chat[0] if tokenize: out = self( - rendered, + rendered_chat, padding=padding, truncation=truncation, max_length=max_length, @@ -1872,53 +1669,91 @@ def apply_chat_template( **tokenizer_kwargs, ) if return_dict: + if return_assistant_tokens_mask: + assistant_masks = [] + if is_batched or return_tensors: + input_ids = out["input_ids"] + else: + input_ids = [out["input_ids"]] + for i in range(len(input_ids)): + current_mask = [0] * len(input_ids[i]) + for assistant_start_char, assistant_end_char in generation_indices[i]: + start_token = out.char_to_token(i, assistant_start_char) + end_token = out.char_to_token(i, assistant_end_char - 1) + if start_token is None: + # start_token is out of bounds maybe due to truncation. + break + for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): + current_mask[token_id] = 1 + assistant_masks.append(current_mask) + + if not is_batched and not return_tensors: + assistant_masks = assistant_masks[0] + + out["assistant_masks"] = assistant_masks + + if return_tensors: + out.convert_to_tensors(tensor_type=return_tensors) + return out else: return out["input_ids"] else: - return rendered + return rendered_chat - @lru_cache - def _compile_jinja_template(self, chat_template): - try: - import jinja2 - from jinja2.exceptions import TemplateError - from jinja2.sandbox import ImmutableSandboxedEnvironment - except ImportError: - raise ImportError("apply_chat_template requires jinja2 to be installed.") - - if version.parse(jinja2.__version__) < version.parse("3.1.0"): - raise ImportError( - "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}." - ) + def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str: + """ + Retrieve the chat template string used for tokenizing chat messages. This template is used + internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat + template for better generation tracking. - def raise_exception(message): - raise TemplateError(message) + Args: + chat_template (`str`, *optional*): + A Jinja template or the name of a template to use for this conversion. + It is usually not necessary to pass anything to this argument, + as the model's template will be used by default. + tools (`list[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. 
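The two mutually exclusive options checked here are easiest to compare side by side. This is a minimal sketch against the public API; the checkpoint name is a placeholder.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some/chat-model")  # placeholder checkpoint
chat = [{"role": "user", "content": "Write a haiku about rain."}]

# Append the assistant header so generation starts a brand-new reply.
prompt = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Or prefill part of the reply and let the model continue it; this cannot be
# combined with add_generation_prompt.
chat_prefill = chat + [{"role": "assistant", "content": "Soft rain on the roof,"}]
prefix = tok.apply_chat_template(chat_prefill, tokenize=False, continue_final_message=True)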
Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. - def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): - # We override the built-in tojson filter because Jinja's default filter escapes HTML characters - # We also expose some options like custom indents and separators - return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) + Returns: + `str`: The chat template string. + """ + # First, handle the cases when the model has a dict of multiple templates + if isinstance(self.chat_template, dict): + template_dict = self.chat_template + if chat_template is not None and chat_template in template_dict: + # The user can pass the name of a template to the chat template argument instead of an entire template + chat_template = template_dict[chat_template] + elif chat_template is None: + if tools is not None and "tool_use" in template_dict: + chat_template = template_dict["tool_use"] + elif "default" in template_dict: + chat_template = template_dict["default"] + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template_dict.keys())}." + ) - jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) - jinja_env.filters["tojson"] = tojson - jinja_env.globals["raise_exception"] = raise_exception - return jinja_env.from_string(chat_template) + elif chat_template is None: + # These are the cases when the model has a single template + # priority: `chat_template` argument > `tokenizer.chat_template` + if self.chat_template is not None: + chat_template = self.chat_template + else: + raise ValueError( + "Cannot use chat template functions because tokenizer.chat_template is not set and no template " + "argument was passed! For information about writing templates and setting the " + "tokenizer.chat_template attribute, please see the documentation at " + "https://huggingface.co/docs/transformers/main/en/chat_templating" + ) - @property - def default_chat_template(self): - """ - This template formats inputs in the standard ChatML format. See - https://github.com/openai/openai-python/blob/main/chatml.md - """ - return ( - "{% for message in messages %}" - "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{ '<|im_start|>assistant\n' }}" - "{% endif %}" - ) + return chat_template @classmethod def from_pretrained( @@ -1957,12 +1792,12 @@ def from_pretrained( resume_download: Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers. - proxies (`Dict[str, str]`, *optional*): + proxies (`dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). 
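For example, a call exercising several of the arguments documented above could look like the following sketch; the repository id and proxy address are placeholders.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "org/private-model",                       # placeholder repo id
    revision="main",
    token=True,                                # reuse the token stored by `hf auth login`
    proxies={"https": "proxy.example:3128"},   # placeholder proxy
    cache_dir="./hf_cache",
)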
+ when running `hf auth login` (stored in `~/.huggingface`). local_files_only (`bool`, *optional*, defaults to `False`): Whether or not to only rely on local files and not to attempt to download any files. revision (`str`, *optional*, defaults to `"main"`): @@ -2071,28 +1906,42 @@ def from_pretrained( "tokenizer_config_file": TOKENIZER_CONFIG_FILE, # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders "tokenizer_file": FULL_TOKENIZER_FILE, + "chat_template_file": CHAT_TEMPLATE_FILE, } + vocab_files = {**cls.vocab_files_names, **additional_files_names} if "tokenizer_file" in vocab_files: # Try to get the tokenizer config to see if there are versioned tokenizer files. fast_tokenizer_file = FULL_TOKENIZER_FILE - resolved_config_file = cached_file( - pretrained_model_name_or_path, - TOKENIZER_CONFIG_FILE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - subfolder=subfolder, - user_agent=user_agent, - _raise_exceptions_for_gated_repo=False, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, - _commit_hash=commit_hash, - ) + + try: + resolved_config_file = cached_file( + pretrained_model_name_or_path, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + user_agent=user_agent, + _raise_exceptions_for_missing_entries=False, + _commit_hash=commit_hash, + ) + except OSError: + # Re-raise any error raised by cached_file in order to get a helpful error message + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing all relevant files for a {cls.__name__} tokenizer." 
+ ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) if resolved_config_file is not None: with open(resolved_config_file, encoding="utf-8") as reader: @@ -2101,9 +1950,26 @@ def from_pretrained( fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) vocab_files["tokenizer_file"] = fast_tokenizer_file + # This block looks for any extra chat template files + if is_local: + template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) + if template_dir.is_dir(): + for template_file in template_dir.glob("*.jinja"): + template_name = template_file.name.removesuffix(".jinja") + vocab_files[f"chat_template_{template_name}"] = ( + f"{CHAT_TEMPLATE_DIR}/{template_file.name}" + ) + else: + for template in list_repo_templates( + pretrained_model_name_or_path, + local_files_only=local_files_only, + revision=revision, + cache_dir=cache_dir, + ): + vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja" + # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} - unresolved_files = [] for file_id, file_path in vocab_files.items(): if file_path is None: resolved_vocab_files[file_id] = None @@ -2113,41 +1979,35 @@ def from_pretrained( elif is_remote_url(file_path): resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies) else: - resolved_vocab_files[file_id] = cached_file( - pretrained_model_name_or_path, - file_path, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_gated_repo=False, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, - _commit_hash=commit_hash, - ) + try: + resolved_vocab_files[file_id] = cached_file( + pretrained_model_name_or_path, + file_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _commit_hash=commit_hash, + ) + except OSError: + # Re-raise any error raised by cached_file in order to get a helpful error message + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing all relevant files for a {cls.__name__} tokenizer." + ) commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash) - if len(unresolved_files) > 0: - logger.info( - f"Can't load following files from cache: {unresolved_files} and cannot check if these " - "files are necessary for the tokenizer to operate." - ) - - # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be - # loaded directly from the GGUF file. - if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not gguf_file: - raise EnvironmentError( - f"Can't load tokenizer for '{pretrained_model_name_or_path}'. 
If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing all relevant files for a {cls.__name__} tokenizer." - ) - for file_id, file_path in vocab_files.items(): if file_id not in resolved_vocab_files: continue @@ -2227,18 +2087,30 @@ def _from_pretrained( config_tokenizer_class = None init_kwargs = init_configuration + # If independent chat template file(s) exist, they take priority over template entries in the tokenizer config + chat_templates = {} + chat_template_file = resolved_vocab_files.pop("chat_template_file", None) + extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")] + if chat_template_file is not None: + with open(chat_template_file, encoding="utf-8") as chat_template_handle: + chat_templates["default"] = chat_template_handle.read() + for extra_chat_template in extra_chat_templates: + template_file = resolved_vocab_files.pop(extra_chat_template, None) + if template_file is None: + continue # I think this should never happen, but just in case + template_name = extra_chat_template.removeprefix("chat_template_") + with open(template_file) as chat_template_handle: + chat_templates[template_name] = chat_template_handle.read() + if len(chat_templates) == 1 and "default" in chat_templates: + init_kwargs["chat_template"] = chat_templates["default"] + elif chat_templates: + init_kwargs["chat_template"] = chat_templates + if not _is_local: if "auto_map" in init_kwargs: # For backward compatibility with odl format. if isinstance(init_kwargs["auto_map"], (tuple, list)): init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]} - init_kwargs["auto_map"] = add_model_info_to_auto_map( - init_kwargs["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in init_kwargs: - init_kwargs["custom_pipelines"] = add_model_info_to_custom_pipelines( - init_kwargs["custom_pipelines"], pretrained_model_name_or_path - ) if config_tokenizer_class is None: # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo. @@ -2307,9 +2179,9 @@ def _from_pretrained( init_kwargs["__slow_tokenizer"] = slow_tokenizer init_kwargs["name_or_path"] = pretrained_model_name_or_path - # Handle tokenizer serialization of added and special tokens - added_tokens_decoder: Dict[int, AddedToken] = {} - added_tokens_map: Dict[str, AddedToken] = {} + #### Handle tokenizer serialization of added and special tokens + added_tokens_decoder: dict[int, AddedToken] = {} + added_tokens_map: dict[str, AddedToken] = {} # if we have info on the slow added tokens if "added_tokens_decoder" in init_kwargs: for idx, token in init_kwargs["added_tokens_decoder"].items(): @@ -2319,7 +2191,7 @@ def _from_pretrained( added_tokens_decoder[int(idx)] = token added_tokens_map[str(token)] = token else: - raise ValueError( + raise TypeError( f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance" ) else: @@ -2392,6 +2264,19 @@ def _from_pretrained( # Instantiate the tokenizer. try: tokenizer = cls(*init_inputs, **init_kwargs) + except import_protobuf_decode_error(): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." 
+ "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).", + ) + return False + except RuntimeError as e: + if "sentencepiece_processor.cc" in str(e): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).", + ) + return False except OSError: raise OSError( "Unable to load vocabulary from file. " @@ -2399,7 +2284,7 @@ def _from_pretrained( ) if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size: - logger.warning_advice( + logger.info( "Special tokens have been added in the vocabulary, make sure the associated word embeddings are" " fine-tuned or trained." ) @@ -2432,6 +2317,61 @@ def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_ return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()} return obj + def save_chat_templates( + self, + save_directory: Union[str, os.PathLike], + tokenizer_config: dict, + filename_prefix: Optional[str], + save_jinja_files: bool, + ): + """ + Writes chat templates out to the save directory if we're using the new format, and removes them from + the tokenizer config if present. If we're using the legacy format, it doesn't write any files, and instead + writes the templates to the tokenizer config in the correct format. + """ + chat_template_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE + ) + chat_template_dir = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR + ) + + saved_raw_chat_template_files = [] + if save_jinja_files and isinstance(self.chat_template, str): + # New format for single templates is to save them as chat_template.jinja + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {chat_template_file}") + saved_raw_chat_template_files.append(chat_template_file) + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + elif save_jinja_files and isinstance(self.chat_template, dict): + # New format for multiple templates is to save the default as chat_template.jinja + # and the other templates in the chat_templates/ directory + for template_name, template in self.chat_template.items(): + if template_name == "default": + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template["default"]) + logger.info(f"chat template saved in {chat_template_file}") + saved_raw_chat_template_files.append(chat_template_file) + else: + Path(chat_template_dir).mkdir(exist_ok=True) + template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") + with open(template_filepath, "w", encoding="utf-8") as f: + f.write(template) + logger.info(f"chat template saved in {template_filepath}") + saved_raw_chat_template_files.append(template_filepath) + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + elif isinstance(self.chat_template, dict): + # Legacy format for multiple templates: + # chat template dicts are saved to the config as lists of dicts with fixed key names. 
+ tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] + elif self.chat_template is not None: + # Legacy format for single templates: Just make them a key in tokenizer_config.json + tokenizer_config["chat_template"] = self.chat_template + return tokenizer_config, saved_raw_chat_template_files + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -2439,7 +2379,7 @@ def save_pretrained( filename_prefix: Optional[str] = None, push_to_hub: bool = False, **kwargs, - ) -> Tuple[str]: + ) -> tuple[str]: """ Save the full tokenizer state. @@ -2469,7 +2409,7 @@ def save_pretrained( Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). - kwargs (`Dict[str, Any]`, *optional*): + kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Returns: @@ -2518,16 +2458,16 @@ def save_pretrained( if hasattr(self, k): tokenizer_config[k] = getattr(self, k) - # Let's make sure we properly save the special tokens. + # Let's make sure we properly save the special tokens tokenizer_config.update(self.special_tokens_map) + if "extra_special_tokens" not in tokenizer_config: + tokenizer_config["extra_special_tokens"] = self.extra_special_tokens + tokenizer_config.update(self.extra_special_tokens) - if self.chat_template is not None: - if isinstance(self.chat_template, dict): - # Chat template dicts are saved to the config as lists of dicts with fixed key names. - # They will be reconstructed as a single dict during loading. - tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] - else: - tokenizer_config["chat_template"] = self.chat_template + save_jinja_files = kwargs.get("save_jinja_files", True) + tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates( + save_directory, tokenizer_config, filename_prefix, save_jinja_files + ) if len(self.init_inputs) > 0: tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) @@ -2537,7 +2477,7 @@ def save_pretrained( # no typefields, this way old fast and slow can load it tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True) - # Process added tokens seperatly: allows previous versions to ignore it! + # Process added tokens separately: allows previous versions to ignore it! 
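A small sketch of the two serialization paths handled above: by default the template is written to a standalone chat_template.jinja, while passing save_jinja_files=False keeps the legacy entry inside tokenizer_config.json. The checkpoint name and template are placeholders.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some/chat-model")  # placeholder checkpoint
tok.chat_template = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"  # toy template

tok.save_pretrained("./tok_new_format")                      # writes chat_template.jinja
tok.save_pretrained("./tok_legacy", save_jinja_files=False)  # embeds the template in tokenizer_config.json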
added_tokens = {} for key, value in self.added_tokens_decoder.items(): added_tokens[key] = value.__getstate__() @@ -2545,8 +2485,8 @@ def save_pretrained( # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained tokenizer_class = self.__class__.__name__ - # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast` - if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast": + # Remove the Fast at the end if we can save the slow tokenizer + if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False): tokenizer_class = tokenizer_class[:-4] tokenizer_config["tokenizer_class"] = tokenizer_class if getattr(self, "_auto_map", None) is not None: @@ -2564,6 +2504,8 @@ def save_pretrained( tokenizer_config.pop("name_or_path") tokenizer_config.pop("special_tokens_map_file", None) tokenizer_config.pop("tokenizer_file", None) + if "device_map" in tokenizer_config: + tokenizer_config.pop("device_map") with open(tokenizer_config_file, "w", encoding="utf-8") as f: out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n" @@ -2579,7 +2521,7 @@ def save_pretrained( f.write(out_str) logger.info(f"Special tokens file saved in {special_tokens_map_file}") - file_names = (tokenizer_config_file, special_tokens_map_file) + file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files) save_files = self._save_pretrained( save_directory=save_directory, @@ -2602,10 +2544,10 @@ def save_pretrained( def _save_pretrained( self, save_directory: Union[str, os.PathLike], - file_names: Tuple[str], + file_names: tuple[str], legacy_format: Optional[bool] = None, filename_prefix: Optional[str] = None, - ) -> Tuple[str]: + ) -> tuple[str]: """ Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens. @@ -2634,7 +2576,7 @@ def _save_pretrained( return file_names + vocab_files + (added_tokens_file,) - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: """ Save only the vocabulary of the tokenizer (vocabulary + added tokens). @@ -2652,7 +2594,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = """ raise NotImplementedError - def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]: """ Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. @@ -2668,7 +2610,7 @@ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bo [`~PreTrainedTokenizerBase.__call__`] Returns: - `List[str]`: The list of tokens. + `list[str]`: The list of tokens. """ raise NotImplementedError @@ -2679,7 +2621,7 @@ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bo """, """ Returns: - `List[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text. + `list[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text. 
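The relationship between tokenize, convert_tokens_to_ids and encode can be summarized with a short sketch; bert-base-uncased is used purely for concreteness.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tok.tokenize("Hello world!")          # e.g. ['hello', 'world', '!']
ids = tok.convert_tokens_to_ids(tokens)
with_special = tok.encode("Hello world!")      # same ids wrapped in [CLS] ... [SEP]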
""", ) def encode( @@ -2688,23 +2630,24 @@ def encode( text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, + truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs, - ) -> List[int]: + ) -> list[int]: """ Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. Same as doing `self.convert_tokens_to_ids(self.tokenize(text))`. Args: - text (`str`, `List[str]` or `List[int]`): + text (`str`, `list[str]` or `list[int]`): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method). - text_pair (`str`, `List[str]` or `List[int]`, *optional*): + text_pair (`str`, `list[str]` or `list[int]`, *optional*): Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method). @@ -2717,6 +2660,7 @@ def encode( truncation=truncation, max_length=max_length, stride=stride, + padding_side=padding_side, return_tensors=return_tensors, **kwargs, ) @@ -2730,11 +2674,8 @@ def _get_padding_truncation_strategies( self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs ): """ - Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy - and pad_to_max_length) and behaviors. + Find the correct padding/truncation strategy """ - old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") - old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) # Backward compatibility for previous behavior, maybe we should deprecate it: # If you only set max_length, it activates truncation for max_length @@ -2752,21 +2693,7 @@ def _get_padding_truncation_strategies( truncation = "longest_first" # Get padding strategy - if padding is False and old_pad_to_max_length: - if verbose: - warnings.warn( - "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " - "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " - "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " - "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " - "maximal input size of the model (e.g. 512 for Bert).", - FutureWarning, - ) - if max_length is None: - padding_strategy = PaddingStrategy.LONGEST - else: - padding_strategy = PaddingStrategy.MAX_LENGTH - elif padding is not False: + if padding is not False: if padding is True: if verbose: if max_length is not None and ( @@ -2776,8 +2703,6 @@ def _get_padding_truncation_strategies( "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. " "To pad to max length, use `padding='max_length'`." 
) - if old_pad_to_max_length is not False: - warnings.warn("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) @@ -2787,21 +2712,7 @@ def _get_padding_truncation_strategies( padding_strategy = PaddingStrategy.DO_NOT_PAD # Get truncation strategy - if truncation is None and old_truncation_strategy != "do_not_truncate": - if verbose: - warnings.warn( - "The `truncation_strategy` argument is deprecated and will be removed in a future version, use" - " `truncation=True` to truncate examples to a max length. You can give a specific length with" - " `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input" - " size of the model (e.g. 512 for Bert). If you have pairs of inputs, you can give a specific" - " truncation strategy selected among `truncation='only_first'` (will only truncate the first" - " sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the" - " pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence" - " in the pairs).", - FutureWarning, - ) - truncation_strategy = TruncationStrategy(old_truncation_strategy) - elif truncation is not False and truncation is not None: + if truncation is not False and truncation is not None: if truncation is True: truncation_strategy = ( TruncationStrategy.LONGEST_FIRST @@ -2867,19 +2778,20 @@ def _get_padding_truncation_strategies( @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, - text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None, text_pair_target: Optional[ - Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] ] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, + truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -2895,19 +2807,19 @@ def __call__( sequences. Args: - text (`str`, `List[str]`, `List[List[str]]`, *optional*): + text (`str`, `list[str]`, `list[list[str]]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
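The per-call padding_side override added here can be exercised as below; any tokenizer with a pad token works, bert-base-uncased is chosen only for concreteness.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tok(
    ["Short prompt", "A noticeably longer prompt that forces padding"],
    padding=True,             # pad to the longest sequence in the batch
    padding_side="left",      # override the tokenizer's default side for this call only
    return_tensors="np",
)
# Zeros in the attention mask now sit on the left of the shorter sequence.
batch["attention_mask"]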
- text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*): + text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - text_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + text_target (`str`, `list[str]`, `list[list[str]]`, *optional*): The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*): The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). @@ -2921,6 +2833,7 @@ def __call__( "stride": stride, "is_split_into_words": is_split_into_words, "pad_to_multiple_of": pad_to_multiple_of, + "padding_side": padding_side, "return_tensors": return_tensors, "return_token_type_ids": return_token_type_ids, "return_attention_mask": return_attention_mask, @@ -2931,6 +2844,12 @@ def __call__( "split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens), "verbose": verbose, } + + if return_tensors in ("tf", "jax"): + logger.warning_once( + "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " + "recommend migrating to PyTorch classes or pinning your version of Transformers." + ) all_kwargs.update(kwargs) if text is None and text_target is None: raise ValueError("You need to specify either `text` or `text_target`.") @@ -2956,15 +2875,16 @@ def __call__( def _call_one( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], - text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, + truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -2999,14 +2919,14 @@ def _is_valid_text_input(t): if not _is_valid_text_input(text): raise ValueError( - "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) " - "or `List[List[str]]` (batch of pretokenized examples)." + "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) " + "or `list[list[str]]` (batch of pretokenized examples)." 
) if text_pair is not None and not _is_valid_text_input(text_pair): raise ValueError( - "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) " - "or `List[List[str]]` (batch of pretokenized examples)." + "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) " + "or `list[list[str]]` (batch of pretokenized examples)." ) if is_split_into_words: @@ -3035,6 +2955,7 @@ def _is_valid_text_input(t): stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, @@ -3057,6 +2978,7 @@ def _is_valid_text_input(t): stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, @@ -3076,11 +2998,12 @@ def encode_plus( text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, + truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -3101,17 +3024,16 @@ def encode_plus( Args: - text (`str`, `List[str]` or `List[int]` (the latter only for not-fast tokenizers)): + text (`str`, `list[str]` or (for non-fast tokenizers) `list[int]`): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method). - text_pair (`str`, `List[str]` or `List[int]`, *optional*): + text_pair (`str`, `list[str]` or `list[int]`, *optional*): Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method). 
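Encoding a sentence pair goes through the same machinery; a quick sketch with a BERT-style tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tok("How old are you?", "I am six years old.", truncation=True, max_length=32)
enc["input_ids"]        # [CLS] question [SEP] answer [SEP] as ids
enc["token_type_ids"]   # 0s for the first segment, 1s for the second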
""" - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( padding=padding, truncation=truncation, @@ -3131,6 +3053,7 @@ def encode_plus( stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, @@ -3154,6 +3077,7 @@ def _encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -3171,20 +3095,21 @@ def _encode_plus( def batch_encode_plus( self, batch_text_or_text_pairs: Union[ - List[TextInput], - List[TextInputPair], - List[PreTokenizedInput], - List[PreTokenizedInputPair], - List[EncodedInput], - List[EncodedInputPair], + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + list[PreTokenizedInputPair], + list[EncodedInput], + list[EncodedInputPair], ], add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, + truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -3206,8 +3131,7 @@ def batch_encode_plus( Args: - batch_text_or_text_pairs (`List[str]`, `List[Tuple[str, str]]`, `List[List[str]]`, `List[Tuple[List[str], List[str]]]`, - and for not-fast tokenizers, also `List[List[int]]`, `List[Tuple[List[int], List[int]]]`): + batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`): Batch of sequences or pair of sequences to be encoded. This can be a list of string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see details in `encode_plus`). 
@@ -3232,6 +3156,7 @@ def batch_encode_plus( stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, @@ -3247,12 +3172,12 @@ def batch_encode_plus( def _batch_encode_plus( self, batch_text_or_text_pairs: Union[ - List[TextInput], - List[TextInputPair], - List[PreTokenizedInput], - List[PreTokenizedInputPair], - List[EncodedInput], - List[EncodedInputPair], + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + list[PreTokenizedInputPair], + list[EncodedInput], + list[EncodedInputPair], ], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, @@ -3261,6 +3186,7 @@ def _batch_encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -3278,14 +3204,15 @@ def pad( self, encoded_inputs: Union[ BatchEncoding, - List[BatchEncoding], - Dict[str, EncodedInput], - Dict[str, List[EncodedInput]], - List[Dict[str, EncodedInput]], + list[BatchEncoding], + dict[str, EncodedInput], + dict[str, list[EncodedInput]], + list[dict[str, EncodedInput]], ], padding: Union[bool, str, PaddingStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_attention_mask: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, verbose: bool = True, @@ -3309,23 +3236,23 @@ def pad( Args: - encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `Dict[str, List[int]]`, `Dict[str, List[List[int]]` or `List[Dict[str, List[int]]]`): - Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of - tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str, - List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader + encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`): + Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of + tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str, + list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader collate function. - Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see + Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see the note above for the return type. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum acceptable input length for the model if that argument is not provided. 
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). max_length (`int`, *optional*): Maximum length of the returned list and optionally padding length (see above). @@ -3334,6 +3261,9 @@ def pad( This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask (`bool`, *optional*): Whether to return the attention mask. If left to the default, will return the attention mask according to the specific tokenizer's default, defined by the `return_outputs` attribute. @@ -3362,7 +3292,7 @@ def pad( if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} - # The model's main input name, usually `input_ids`, has be passed for padding + # The model's main input name, usually `input_ids`, has been passed for padding if self.model_input_names[0] not in encoded_inputs: raise ValueError( "You should supply an encoding or a list of encodings to this method " @@ -3416,14 +3346,15 @@ def pad( max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) return BatchEncoding(encoded_inputs, tensor_type=return_tensors) batch_size = len(required_input) - assert all( - len(v) == batch_size for v in encoded_inputs.values() - ), "Some items in the output dictionary have a different batch size than others." + assert all(len(v) == batch_size for v in encoded_inputs.values()), ( + "Some items in the output dictionary have a different batch size than others." + ) if padding_strategy == PaddingStrategy.LONGEST: max_length = max(len(inputs) for inputs in required_input) @@ -3437,6 +3368,7 @@ def pad( max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) @@ -3448,8 +3380,8 @@ def pad( return BatchEncoding(batch_outputs, tensor_type=return_tensors) def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: """ Create the token type IDs corresponding to the sequences passed. [What are token type IDs?](../glossary#token-type-ids) @@ -3457,19 +3389,23 @@ def create_token_type_ids_from_sequences( Should be overridden in a subclass if the model has a special way of building those. Args: - token_ids_0 (`List[int]`): The first tokenized sequence. - token_ids_1 (`List[int]`, *optional*): The second tokenized sequence. + token_ids_0 (`list[int]`): The first tokenized sequence. + token_ids_1 (`list[int]`, *optional*): The second tokenized sequence. Returns: - `List[int]`: The token type ids. + `list[int]`: The token type ids. 
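In practice pad is often used as a collate step over already-tokenized examples, and the new padding_side argument lets a single call override the tokenizer default. The token ids below are illustrative only.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
features = [
    {"input_ids": [101, 7592, 102]},
    {"input_ids": [101, 7592, 2088, 999, 102]},
]
batch = tok.pad(features, padding="longest", padding_side="left", return_tensors="np")
batch["input_ids"].shape   # (2, 5); the attention mask is created automatically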
""" + cls_len = int(getattr(self, "cls_token_id", None) is not None) + sep_len = int(getattr(self, "sep_token_id", None) is not None) + if token_ids_1 is None: - return len(token_ids_0) * [0] - return [0] * len(token_ids_0) + [1] * len(token_ids_1) + return [0] * (cls_len + len(token_ids_0) + sep_len) + + return [0] * (cls_len + len(token_ids_0) + sep_len) + [1] * (len(token_ids_1) + sep_len) def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. @@ -3477,11 +3413,11 @@ def build_inputs_with_special_tokens( This implementation does not add special tokens and this method should be overridden in a subclass. Args: - token_ids_0 (`List[int]`): The first tokenized sequence. - token_ids_1 (`List[int]`, *optional*): The second tokenized sequence. + token_ids_0 (`list[int]`): The first tokenized sequence. + token_ids_1 (`list[int]`, *optional*): The second tokenized sequence. Returns: - `List[int]`: The model input with special tokens. + `list[int]`: The model input with special tokens. """ if token_ids_1 is None: return token_ids_0 @@ -3490,14 +3426,15 @@ def build_inputs_with_special_tokens( @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def prepare_for_model( self, - ids: List[int], - pair_ids: Optional[List[int]] = None, + ids: list[int], + pair_ids: Optional[list[int]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, + truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -3517,10 +3454,10 @@ def prepare_for_model( overflowing tokens. Such a combination of arguments will raise an error. Args: - ids (`List[int]`): + ids (`list[int]`): Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and `convert_tokens_to_ids` methods. - pair_ids (`List[int]`, *optional*): + pair_ids (`list[int]`, *optional*): Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` and `convert_tokens_to_ids` methods. 
""" @@ -3611,32 +3548,35 @@ def prepare_for_model( max_length=max_length, padding=padding_strategy.value, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) if return_length: encoded_inputs["length"] = len(encoded_inputs["input_ids"]) - batch_outputs = BatchEncoding(encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis) + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) return batch_outputs def truncate_sequences( self, - ids: List[int], - pair_ids: Optional[List[int]] = None, + ids: list[int], + pair_ids: Optional[list[int]] = None, num_tokens_to_remove: int = 0, truncation_strategy: Union[str, TruncationStrategy] = "longest_first", stride: int = 0, - ) -> Tuple[List[int], List[int], List[int]]: + ) -> tuple[list[int], list[int], list[int]]: """ Truncates a sequence pair in-place following the strategy. Args: - ids (`List[int]`): + ids (`list[int]`): Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and `convert_tokens_to_ids` methods. - pair_ids (`List[int]`, *optional*): + pair_ids (`list[int]`, *optional*): Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` and `convert_tokens_to_ids` methods. num_tokens_to_remove (`int`, *optional*, defaults to 0): @@ -3661,7 +3601,7 @@ def truncate_sequences( sequence returned. The value of this argument defines the number of additional tokens. Returns: - `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + `tuple[list[int], list[int], list[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair of sequences (or a batch of pairs) is provided. """ @@ -3747,10 +3687,11 @@ def truncate_sequences( def _pad( self, - encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -3758,7 +3699,7 @@ def _pad( Args: encoded_inputs: - Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`). max_length: maximum length of the returned list and optionally padding length (see below). Will truncate by taking into account the special tokens. padding_strategy: PaddingStrategy to use for padding. @@ -3766,13 +3707,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta). 
+ padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -3796,8 +3740,9 @@ def _pad( if needs_to_be_padded: difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side - if self.padding_side == "right": + if padding_side == "right": if return_attention_mask: encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference if "token_type_ids" in encoded_inputs: @@ -3807,7 +3752,7 @@ def _pad( if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference - elif self.padding_side == "left": + elif padding_side == "left": if return_attention_mask: encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] if "token_type_ids" in encoded_inputs: @@ -3818,17 +3763,17 @@ def _pad( encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input else: - raise ValueError(f"Invalid padding strategy:{self.padding_side}") + raise ValueError(f"Invalid padding strategy:{padding_side}") return encoded_inputs - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: """ Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we often want to remove sub-word tokenization artifacts at the same time. Args: - tokens (`List[str]`): The token to join in a string. + tokens (`list[str]`): The token to join in a string. Returns: `str`: The joined tokens. @@ -3837,16 +3782,16 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: def batch_decode( self, - sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool = None, + clean_up_tokenization_spaces: Optional[bool] = None, **kwargs, - ) -> List[str]: + ) -> list[str]: """ Convert a list of lists of token ids into a list of strings by calling decode. Args: - sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -3857,7 +3802,7 @@ def batch_decode( Will be passed to the underlying model specific decode method. Returns: - `List[str]`: The list of decoded sentences. + `list[str]`: The list of decoded sentences. 
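Decoding mirrors encoding; a quick round trip, again with bert-base-uncased for concreteness:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tok(["Hello world!", "Another sentence."], padding=True)
texts = tok.batch_decode(enc["input_ids"], skip_special_tokens=True)
# roughly ['hello world!', 'another sentence.'] after tokenization cleanup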
""" return [ self.decode( @@ -3871,9 +3816,9 @@ def batch_decode( def decode( self, - token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool = None, + clean_up_tokenization_spaces: Optional[bool] = None, **kwargs, ) -> str: """ @@ -3883,7 +3828,7 @@ def decode( Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. Args: - token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): List of tokenized input ids. Can be obtained using the `__call__` method. skip_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not to remove special tokens in the decoding. @@ -3908,24 +3853,24 @@ def decode( def _decode( self, - token_ids: Union[int, List[int]], + token_ids: Union[int, list[int]], skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool = None, + clean_up_tokenization_spaces: Optional[bool] = None, **kwargs, ) -> str: raise NotImplementedError def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. Args: - token_ids_0 (`List[int]`): + token_ids_0 (`list[int]`): List of ids of the first sequence. - token_ids_1 (`List[int]`, *optional*): + token_ids_1 (`list[int]`, *optional*): List of ids of the second sequence. already_has_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not the token list is already formatted with special tokens for the model. @@ -3971,18 +3916,18 @@ def clean_up_tokenization(out_string: str) -> str: ) return out_string - def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool): + def _eventual_warn_about_too_long_sequence(self, ids: list[int], max_length: Optional[int], verbose: bool): """ Depending on the input and internal state we might trigger a warning about a sequence that is too long for its corresponding model Args: - ids (`List[str]`): The ids produced by the tokenization + ids (`list[str]`): The ids produced by the tokenization max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set) verbose (`bool`): Whether or not to print more information and warnings. """ - if max_length is None and len(ids) > self.model_max_length and verbose: + if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0: if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): logger.warning( "Token indices sequence length is longer than the specified maximum sequence length " @@ -4026,11 +3971,7 @@ def register_for_auto_class(cls, auto_class="AutoTokenizer"): Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the library are already mapped with `AutoTokenizer`. - - - This API is experimental and may have some slight breaking changes in the next releases. 
- Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`): @@ -4048,12 +3989,12 @@ def register_for_auto_class(cls, auto_class="AutoTokenizer"): def prepare_seq2seq_batch( self, - src_texts: List[str], - tgt_texts: Optional[List[str]] = None, + src_texts: list[str], + tgt_texts: Optional[list[str]] = None, max_length: Optional[int] = None, max_target_length: Optional[int] = None, padding: str = "longest", - return_tensors: str = None, + return_tensors: Optional[str] = None, truncation: bool = True, **kwargs, ) -> BatchEncoding: @@ -4061,7 +4002,7 @@ def prepare_seq2seq_batch( Prepare model inputs for translation. For best performance, translate one sentence at a time. Arguments: - src_texts (`List[str]`): + src_texts (`list[str]`): List of documents to summarize or source language texts. tgt_texts (`list`, *optional*): List of summaries or target language texts. @@ -4169,12 +4110,12 @@ def prepare_seq2seq_batch( return model_inputs -def get_fast_tokenizer_file(tokenization_files: List[str]) -> str: +def get_fast_tokenizer_file(tokenization_files: list[str]) -> str: """ Get the tokenization file to use for this version of transformers. Args: - tokenization_files (`List[str]`): The list of available configuration files. + tokenization_files (`list[str]`): The list of available configuration files. Returns: `str`: The tokenization file to use. From 509a308c931a5b5318b579e7e90ef32381642bf9 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:50:03 +0800 Subject: [PATCH 31/94] update image_transforms.py --- mindone/transformers/image_transforms.py | 158 +++++++++++++++++------ 1 file changed, 119 insertions(+), 39 deletions(-) diff --git a/mindone/transformers/image_transforms.py b/mindone/transformers/image_transforms.py index 8ff53d7f5e..8fa9b969e4 100644 --- a/mindone/transformers/image_transforms.py +++ b/mindone/transformers/image_transforms.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings +from collections import defaultdict from collections.abc import Collection, Iterable from math import ceil from typing import Optional, Union @@ -48,7 +48,9 @@ def to_channel_dimension_format( input_channel_dim: Optional[Union[ChannelDimension, str]] = None, ) -> np.ndarray: """ - Converts `image` to the channel dimension format specified by `channel_dim`. + Converts `image` to the channel dimension format specified by `channel_dim`. The input + can have arbitrary number of leading dimensions. Only last three dimension will be permuted + to format the `image`. Args: image (`numpy.ndarray`): @@ -72,9 +74,11 @@ def to_channel_dimension_format( return image if target_channel_dim == ChannelDimension.FIRST: - image = image.transpose((2, 0, 1)) + axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2] + image = image.transpose(axes) elif target_channel_dim == ChannelDimension.LAST: - image = image.transpose((1, 2, 0)) + axes = list(range(image.ndim - 3)) + [image.ndim - 2, image.ndim - 1, image.ndim - 3] + image = image.transpose(axes) else: raise ValueError(f"Unsupported channel dimension format: {channel_dim}") @@ -399,7 +403,7 @@ def normalize( The channel dimension format of the input image. If unset, will use the inferred format from the input. 
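# Worked example of the axis permutation introduced above in
# `to_channel_dimension_format`, which now only permutes the last three
# dimensions so leading batch dimensions pass through untouched
# (the shape below is made up for illustration):
import numpy as np

img = np.zeros((2, 224, 224, 3))  # (batch, height, width, channels), channels_last
# FIRST branch: move channels in front of the spatial dims, keep leading dims.
axes = list(range(img.ndim - 3)) + [img.ndim - 1, img.ndim - 3, img.ndim - 2]
assert axes == [0, 3, 1, 2]
assert img.transpose(axes).shape == (2, 3, 224, 224)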
""" if not isinstance(image, np.ndarray): - raise ValueError("image must be a numpy array") + raise TypeError("image must be a numpy array") if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -440,7 +444,6 @@ def center_crop( size: tuple[int, int], data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, - return_numpy: Optional[bool] = None, ) -> np.ndarray: """ Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to @@ -461,21 +464,11 @@ def center_crop( - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. If unset, will use the inferred format of the input image. - return_numpy (`bool`, *optional*): - Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the - previous ImageFeatureExtractionMixin method. - - Unset: will return the same type as the input image. - - `True`: will return a numpy array. - - `False`: will return a `PIL.Image.Image` object. Returns: `np.ndarray`: The cropped image. """ requires_backends(center_crop, ["vision"]) - if return_numpy is not None: - warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning) - - return_numpy = True if return_numpy is None else return_numpy if not isinstance(image, np.ndarray): raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") @@ -528,8 +521,6 @@ def center_crop( new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST) - if not return_numpy: - new_image = to_pil_image(new_image) return new_image @@ -726,7 +717,7 @@ def _expand_for_data_format(values): values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0)) # Add additional padding if there's a batch dimension - values = (0, *values) if image.ndim == 4 else values + values = ((0, 0), *values) if image.ndim == 4 else values return values padding = _expand_for_data_format(padding) @@ -811,38 +802,127 @@ def _cast_tensor_to_float(x): return x return x.float() +def _group_images_by_shape(nested_images, is_nested: bool = False): + """Helper function to flatten a single level of nested image structures and group by shape.""" + grouped_images = defaultdict(list) + grouped_images_index = {} + nested_images = [nested_images] if not is_nested else nested_images + for i, sublist in enumerate(nested_images): + for j, image in enumerate(sublist): + key = (i, j) if is_nested else j + shape = image.shape[1:] + grouped_images[shape].append(image) + grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1) + + return grouped_images, grouped_images_index + + +def _reconstruct_nested_structure(indices, processed_images): + """Helper function to reconstruct a single level nested structure.""" + # Find the maximum outer index + max_outer_idx = max(idx[0] for idx in indices.keys()) + + # Create the outer list + result = [None] * (max_outer_idx + 1) + + # Group indices by outer index + nested_indices = defaultdict(list) + for i, j in indices.keys(): + nested_indices[i].append(j) + + for i in range(max_outer_idx + 1): + if i in nested_indices: + inner_max_idx = max(nested_indices[i]) + inner_list = [None] * 
(inner_max_idx + 1) + for j in range(inner_max_idx + 1): + if (i, j) in indices: + shape, idx = indices[(i, j)] + inner_list[j] = processed_images[shape][idx] + result[i] = inner_list + + return result def group_images_by_shape( - images: list["ms.Tensor"], -) -> tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[int, tuple[tuple[int, int], int]]]: + images: Union[list["ms.Tensor"], "ms.Tensor"], + disable_grouping: bool, + is_nested: bool = False, +) -> tuple[ + dict[tuple[int, int], list["ms.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]] +]: """ Groups images by shape. Returns a dictionary with the shape as key and a list of images with that shape as value, and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value. - """ - grouped_images = {} - grouped_images_index = {} - for i, image in enumerate(images): - shape = image.shape[1:] - if shape not in grouped_images: - grouped_images[shape] = [] - grouped_images[shape].append(image) - grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1) - # stack images with the same shape - grouped_images = {shape: mint.stack(images, dim=0) for shape, images in grouped_images.items()} + + The function supports both flat lists of tensors and nested structures. + The input must be either all flat or all nested, not a mix of both. + + Args: + images (Union[list["ms.Tensor"], "ms.Tensor"]): + A list of images or a single tensor + disable_grouping (bool): + Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise. + This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157 + is_nested (bool, *optional*, defaults to False): + Whether the images are nested. + + Returns: + tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]: + - A dictionary with shape as key and list of images with that shape as value + - A dictionary mapping original indices to (shape, index) tuples + """ + # If disable grouping is not explicitely provided, we favor disabling it if the images are on CPU, and enabling it otherwise. + if disable_grouping is None: + device = images[0][0].device if is_nested else images[0].device + disable_grouping = device == "cpu" + + if disable_grouping: + if is_nested: + return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, { + (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i])) + } + else: + return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))} + + # Handle single level nested structure + grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested) + + # Stack images with the same shape + grouped_images = {shape: mint.stack(images_list, dim=0) for shape, images_list in grouped_images.items()} + return grouped_images, grouped_images_index def reorder_images( - processed_images: dict[tuple[int, int], "ms.Tensor"], grouped_images_index: dict[int, tuple[int, int]] -) -> list["ms.Tensor"]: + processed_images: dict[tuple[int, int], "ms.Tensor"], + grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]], + is_nested: bool = False, +) -> Union[list["ms.Tensor"], "ms.Tensor"]: """ - Reconstructs a list of images in the original order. 
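# A small round-trip sketch for the two helpers above, assuming MindSpore
# tensors in (num_channels, height, width) layout and grouping enabled
# (shapes are made up for illustration):
import numpy as np
import mindspore as ms

images = [
    ms.Tensor(np.zeros((3, 224, 224), dtype=np.float32)),
    ms.Tensor(np.zeros((3, 112, 112), dtype=np.float32)),
    ms.Tensor(np.zeros((3, 224, 224), dtype=np.float32)),
]
grouped, index = group_images_by_shape(images, disable_grouping=False)
# grouped has one stacked batch per spatial shape:
#   {(224, 224): Tensor (2, 3, 224, 224), (112, 112): Tensor (1, 3, 112, 112)}
processed = {shape: batch * 2.0 for shape, batch in grouped.items()}  # any batched op
restored = reorder_images(processed, index)
# restored is again a flat list of three tensors, in the original order.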
+ Reconstructs images in the original order, preserving the original structure (nested or not). + The input structure is either all flat or all nested. + + Args: + processed_images (dict[tuple[int, int], "ms.Tensor"]): + Dictionary mapping shapes to batched processed images. + grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]): + Dictionary mapping original indices to (shape, index) tuples. + is_nested (bool, *optional*, defaults to False): + Whether the images are nested. Cannot be infered from the input, as some processing functions outputs nested images. + even with non nested images,e.g functions splitting images into patches. We thus can't deduce is_nested from the input. + + + Returns: + Union[list["ms.Tensor"], "ms.Tensor"]: + Images in the original structure. """ - return [ - processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]] - for i in range(len(grouped_images_index)) - ] + if not is_nested: + return [ + processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]] + for i in range(len(grouped_images_index)) + ] + + return _reconstruct_nested_structure(grouped_images_index, processed_images) class NumpyToTensor: From 33ed2beab2170a549d1c6dfa4b5ea593cb5efab8 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:58:12 +0800 Subject: [PATCH 32/94] update video_utils.py and image_utils.py --- mindone/transformers/image_utils.py | 460 +++------------------------- mindone/transformers/video_utils.py | 47 +++ 2 files changed, 83 insertions(+), 424 deletions(-) diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py index 6d8835952d..14246083ca 100644 --- a/mindone/transformers/image_utils.py +++ b/mindone/transformers/image_utils.py @@ -18,24 +18,14 @@ import base64 import os from collections.abc import Iterable -from contextlib import redirect_stdout from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Callable, Optional, Union - +from typing import TYPE_CHECKING, Optional, Union +from mindspore import mint import numpy as np import requests from packaging import version -from transformers import is_av_available -from transformers.utils import is_cv2_available, is_decord_available, is_yt_dlp_available, logging -from transformers.utils.constants import ( # noqa: F401 - IMAGENET_DEFAULT_MEAN, - IMAGENET_DEFAULT_STD, - IMAGENET_STANDARD_MEAN, - IMAGENET_STANDARD_STD, - OPENAI_CLIP_MEAN, - OPENAI_CLIP_STD, -) +from transformers.utils import logging from .utils import ( ExplicitEnum, @@ -85,18 +75,6 @@ ] # noqa -VideoInput = Union[ - list["PIL.Image.Image"], - "np.ndarray", - "mindspore.Tensor", - list["np.ndarray"], - list["mindspore.Tensor"], - list[list["PIL.Image.Image"]], - list[list["np.ndarrray"]], - list[list["mindspore.Tensor"]], -] # noqa - - class ChannelDimension(ExplicitEnum): FIRST = "channels_first" LAST = "channels_last" @@ -112,13 +90,6 @@ class AnnotionFormat(ExplicitEnum): COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value -@dataclass -class VideoMetadata: - total_num_frames: int - fps: float - duration: float - video_backend: str - AnnotationType = dict[str, Union[int, str, list[dict]]] @@ -152,6 +123,13 @@ def is_valid_image(img): def is_valid_list_of_images(images: list): return images and all(is_valid_image(image) for image in images) +def concatenate_list(input_list): + if isinstance(input_list[0], list): + return [item for sublist in input_list for item in sublist] + elif 
isinstance(input_list[0], np.ndarray): + return np.concatenate(input_list, axis=0) + elif isinstance(input_list[0], mint.Tensor): + return mint.cat(input_list, dim=0) def valid_images(imgs): # If we have an list of images, make sure every image is valid @@ -223,13 +201,16 @@ def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]: def make_flat_list_of_images( images: Union[list[ImageInput], ImageInput], + expected_ndims: int = 3, ) -> ImageInput: """ Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1. If the input is a nested list of images, it is converted to a flat list of images. Args: - images (`Union[List[ImageInput], ImageInput]`): + images (`Union[list[ImageInput], ImageInput]`): The input image. + expected_ndims (`int`, *optional*, defaults to 3): + The expected number of dimensions for a single input image. Returns: list: A list of images or a 4d array of images. """ @@ -242,15 +223,15 @@ def make_flat_list_of_images( return [img for img_list in images for img in img_list] if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - if is_pil_image(images[0]) or images[0].ndim == 3: + if is_pil_image(images[0]) or images[0].ndim == expected_ndims: return images - if images[0].ndim == 4: + if images[0].ndim == expected_ndims + 1: return [img for img_list in images for img in img_list] if is_valid_image(images): - if is_pil_image(images) or images.ndim == 3: + if is_pil_image(images) or images.ndim == expected_ndims: return [images] - if images.ndim == 4: + if images.ndim == expected_ndims + 1: return list(images) raise ValueError(f"Could not make a flat list of images from {images}") @@ -258,12 +239,15 @@ def make_flat_list_of_images( def make_nested_list_of_images( images: Union[list[ImageInput], ImageInput], + expected_ndims: int = 3, ) -> ImageInput: """ Ensure that the output is a nested list of images. Args: - images (`Union[List[ImageInput], ImageInput]`): + images (`Union[list[ImageInput], ImageInput]`): The input image. + expected_ndims (`int`, *optional*, defaults to 3): + The expected number of dimensions for a single input image. Returns: list: A list of list of images or a list of 4d array of images. """ @@ -277,52 +261,21 @@ def make_nested_list_of_images( # If it's a list of images, it's a single batch, so convert it to a list of lists if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - if is_pil_image(images[0]) or images[0].ndim == 3: + if is_pil_image(images[0]) or images[0].ndim == expected_ndims: return [images] - if images[0].ndim == 4: + if images[0].ndim == expected_ndims + 1: return [list(image) for image in images] # If it's a single image, convert it to a list of lists if is_valid_image(images): - if is_pil_image(images) or images.ndim == 3: + if is_pil_image(images) or images.ndim == expected_ndims: return [[images]] - if images.ndim == 4: + if images.ndim == expected_ndims + 1: return [list(images)] raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.") -def make_batched_videos(videos) -> VideoInput: - """ - Ensure that the input is a list of videos. - Args: - videos (`VideoInput`): - Video or videos to turn into a list of videos. - Returns: - list: A list of videos. 
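# Quick sketch of the list-normalising helpers above with the new
# `expected_ndims` argument (the PIL image is a stand-in input):
from PIL import Image

img = Image.new("RGB", (32, 32))
make_flat_list_of_images([[img, img], [img]])  # -> [img, img, img]
make_nested_list_of_images([img, img])         # -> [[img, img]]
# For 2D inputs such as segmentation maps, pass expected_ndims=2 so a single
# (H, W) array is accepted as one image.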
- """ - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - # case 1: nested batch of videos so we flatten it - if not is_pil_image(videos[0][0]) and videos[0][0].ndim == 4: - videos = [[video for batch_list in batched_videos for video in batch_list] for batched_videos in videos] - # case 2: list of videos represented as list of video frames - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if is_pil_image(videos[0]) or videos[0].ndim == 3: - return [videos] - elif videos[0].ndim == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos): - if is_pil_image(videos) or videos.ndim == 3: - return [[videos]] - elif videos.ndim == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - def to_numpy_array(img) -> np.ndarray: if not is_valid_image(img): raise ValueError(f"Invalid image type: {type(img)}") @@ -331,7 +284,6 @@ def to_numpy_array(img) -> np.ndarray: return np.array(img) return to_numpy(img) - def pil_to_tensor(image, is_normalize=True): """ Pillow image to mindspore tensor @@ -374,12 +326,14 @@ def infer_channel_dimension_format( first_dim, last_dim = 0, 2 elif image.ndim == 4: first_dim, last_dim = 1, 3 + elif image.ndim == 5: + first_dim, last_dim = 2, 4 else: raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels: logger.warning( - f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension." + f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension." ) return ChannelDimension.FIRST elif image.shape[first_dim] in num_channels: @@ -388,7 +342,6 @@ def infer_channel_dimension_format( return ChannelDimension.LAST raise ValueError("Unable to infer channel dimension format") - def get_channel_dimension_axis( image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None ) -> int: @@ -554,347 +507,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = return image -def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs): - """ - A default sampling function that replicates the logic used in get_uniform_frame_indices, - while optionally handling `fps` if `num_frames` is not provided. - - Args: - metadata (`VideoMetadata`): - `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps". - num_frames (`int`, *optional*): - Number of frames to sample uniformly. - fps (`int`, *optional*): - Desired frames per second. Takes priority over num_frames if both are provided. - - Returns: - `np.ndarray`: Array of frame indices to sample. - """ - total_num_frames = metadata.total_num_frames - video_fps = metadata.fps - - # If num_frames is not given but fps is, calculate num_frames from fps - if num_frames is None and fps is not None: - num_frames = int(total_num_frames / video_fps * fps) - if num_frames > total_num_frames: - raise ValueError( - f"When loading the video with fps={fps}, we computed num_frames={num_frames} " - f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata." 
- ) - - if num_frames is not None: - indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int) - else: - indices = np.arange(0, total_num_frames, dtype=int) - return indices - - -def read_video_opencv( - video_path: str, - sample_indices_fn: Callable, - **kwargs, -): - """ - Decode a video using the OpenCV backend. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. - """ - # Lazy import cv2 - requires_backends(read_video_opencv, ["cv2"]) - import cv2 - - video = cv2.VideoCapture(video_path) - total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - video_fps = video.get(cv2.CAP_PROP_FPS) - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv" - ) - indices = sample_indices_fn(metadata=metadata, **kwargs) - - index = 0 - frames = [] - while video.isOpened(): - success, frame = video.read() - if not success: - break - if index in indices: - height, width, channel = frame.shape - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frames.append(frame[0:height, 0:width, 0:channel]) - if success: - index += 1 - if index >= total_num_frames: - break - - video.release() - metadata.frames_indices = indices - return np.stack(frames), metadata - - -def read_video_decord( - video_path: str, - sample_indices_fn: Optional[Callable] = None, - **kwargs, -): - """ - Decode a video using the Decord backend. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. 
- """ - # Lazy import from decord - requires_backends(read_video_decord, ["decord"]) - from decord import VideoReader, cpu - - vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu - video_fps = vr.get_avg_fps() - total_num_frames = len(vr) - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord" - ) - - indices = sample_indices_fn(metadata=metadata, **kwargs) - - frames = vr.get_batch(indices).asnumpy() - metadata.frames_indices = indices - return frames, metadata - - -def read_video_pyav( - video_path: str, - sample_indices_fn: Callable, - **kwargs, -): - """ - Decode the video with PyAV decoder. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. - """ - # Lazy import av - requires_backends(read_video_pyav, ["av"]) - import av - - container = av.open(video_path) - total_num_frames = container.streams.video[0].frames - video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav" - ) - indices = sample_indices_fn(metadata=metadata, **kwargs) - - frames = [] - container.seek(0) - end_index = indices[-1] - for i, frame in enumerate(container.decode(video=0)): - if i > end_index: - break - if i >= 0 and i in indices: - frames.append(frame) - - video = np.stack([x.to_ndarray(format="rgb24") for x in frames]) - metadata.frames_indices = indices - return video, metadata - - -def read_video_mindspore( - video_path: str, - sample_indices_fn: Callable, - **kwargs, -): - """ - Decode the video with mindspore.dataset decoder. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. 
- """ - video, _, info = ms.dataset.vision.read_video( - video_path, - start_pts=0.0, - end_pts=None, - pts_unit="sec", - ) - video_fps = info["video_fps"] - total_num_frames = video.size(0) - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), - fps=float(video_fps), - duration=float(duration), - video_backend="mindspore", - ) - - indices = sample_indices_fn(metadata=metadata, **kwargs) - - video = video[indices].contiguous().numpy() - metadata.frames_indices = indices - return video, metadata - - -VIDEO_DECODERS = { - "decord": read_video_decord, - "opencv": read_video_opencv, - "pyav": read_video_pyav, - "torchvision": read_video_mindspore, -} - - -def load_video( - video: Union[str, "VideoInput"], - num_frames: Optional[int] = None, - fps: Optional[int] = None, - backend: str = "opencv", - sample_indices_fn: Optional[Callable] = None, - **kwargs, -) -> np.array: - """ - Loads `video` to a numpy array. - - Args: - video (`str` or `VideoInput`): - The video to convert to the numpy array format. Can be a link to video or local path. - num_frames (`int`, *optional*): - Number of frames to sample uniformly. If not passed, the whole video is loaded. - fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. - backend (`str`, *optional*, defaults to `"opencv"`): - The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv". - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args. - The function expects at input the all args along with all kwargs passed to `load_video` and should output valid - indices at which the video should be sampled. For example: - - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - tuple[`np.array`, Dict]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - Metadata dictionary. - """ - - # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn` - if fps is not None and num_frames is not None and sample_indices_fn is None: - raise ValueError( - "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" 
- ) - - # If user didn't pass a sampling function, create one on the fly with default logic - if sample_indices_fn is None: - - def sample_indices_fn_func(metadata, **fn_kwargs): - return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs) - - sample_indices_fn = sample_indices_fn_func - - if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"): - if not is_yt_dlp_available(): - raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.") - # Lazy import from yt_dlp - requires_backends(load_video, ["yt_dlp"]) - from yt_dlp import YoutubeDL - - buffer = BytesIO() - with redirect_stdout(buffer), YoutubeDL() as f: - f.download([video]) - bytes_obj = buffer.getvalue() - file_obj = BytesIO(bytes_obj) - elif video.startswith("http://") or video.startswith("https://"): - file_obj = BytesIO(requests.get(video).content) - elif os.path.isfile(video): - file_obj = video - elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])): - file_obj = None - else: - raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.") - - # can also load with decord, but not cv2/torchvision - # both will fail in case of url links - video_is_url = video.startswith("http://") or video.startswith("https://") - if video_is_url and backend in ["opencv", "mindspore"]: - raise ValueError( - "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend" - ) - - if file_obj is None: - return video - - if ( - (not is_decord_available() and backend == "decord") - or (not is_av_available() and backend == "pyav") - or (not is_cv2_available() and backend == "opencv") - or (not is_mindspore_available() and backend == "torchvision") - ): - raise ImportError( - f"You chose backend={backend} for loading the video but the required library is not found in your environment " - f"Make sure to install {backend} before loading the video." - ) - - video_decoder = VIDEO_DECODERS[backend] - video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs) - return video, metadata - - def load_images( images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None ) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]: @@ -1147,7 +759,7 @@ def resize(self, image, size, resample=None, default_to_square=True, max_size=No default_to_square (`bool`, *optional*, defaults to `True`): How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square (`size`,`size`). If set to `False`, will replicate - [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize) + [`torchvision.transforms.Resize`](https://pymint.org/vision/stable/transforms.html#torchvision.transforms.Resize) with support for resizing only the smallest edge and providing an optional `max_size`. max_size (`int`, *optional*, defaults to `None`): The maximum allowed for the longer edge of the resized image: if the longer edge of the image is @@ -1351,12 +963,12 @@ class SizeDict: Hashable dictionary to store image size information. 
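# Brief illustration of SizeDict access, assuming the dataclass constructor
# generated from the fields listed below:
size = SizeDict(height=224, width=224)
size["height"]       # -> 224, via __getitem__
size.shortest_edge   # -> None, unset fields default to None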
""" - height: int = None - width: int = None - longest_edge: int = None - shortest_edge: int = None - max_height: int = None - max_width: int = None + height: Optional[int] = None + width: Optional[int] = None + longest_edge: Optional[int] = None + shortest_edge: Optional[int] = None + max_height: Optional[int] = None + max_width: Optional[int] = None def __getitem__(self, key): if hasattr(self, key): diff --git a/mindone/transformers/video_utils.py b/mindone/transformers/video_utils.py index a07e168c8f..4106f04553 100644 --- a/mindone/transformers/video_utils.py +++ b/mindone/transformers/video_utils.py @@ -397,12 +397,59 @@ def sample_indices_fn(metadata, **kwargs): return video, metadata +def read_video_mindspore( + video_path: str, + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode the video with mindspore.dataset decoder. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + video, _, info = ms.dataset.vision.read_video( + video_path, + start_pts=0.0, + end_pts=None, + pts_unit="sec", + ) + video_fps = info["video_fps"] + total_num_frames = video.size(0) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="mindspore", + ) + + indices = sample_indices_fn(metadata=metadata, **kwargs) + + video = video[indices].contiguous().numpy() + metadata.frames_indices = indices + return video, metadata VIDEO_DECODERS = { "decord": read_video_decord, "opencv": read_video_opencv, "pyav": read_video_pyav, + "torchvision": read_video_mindspore, + "mindspore": read_video_mindspore, } From 0d8142cf8870b39fd57f68b5ec3488047fcca702 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 20 Aug 2025 20:20:52 +0800 Subject: [PATCH 33/94] update image_utils.py & image_processing_utils_fast.py --- .../image_processing_utils_fast.py | 237 ++++++++---------- mindone/transformers/image_utils.py | 9 +- 2 files changed, 109 insertions(+), 137 deletions(-) diff --git a/mindone/transformers/image_processing_utils_fast.py b/mindone/transformers/image_processing_utils_fast.py index 8c5b6bb162..8af174e6e1 100644 --- a/mindone/transformers/image_processing_utils_fast.py +++ b/mindone/transformers/image_processing_utils_fast.py @@ -18,12 +18,12 @@ from collections.abc import Iterable from functools import lru_cache, partial from typing import Any, Optional, TypedDict, Union - +from copy import deepcopy import numpy as np from PIL import Image -from transformers.utils import add_start_docstrings, logging from mindspore import mint +import mindspore as ms from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from .image_transforms import ( @@ -49,6 +49,8 @@ ) from .processing_utils import Unpack from .utils import 
TensorType, is_mindspore_available, is_mindspore_tensor, is_vision_available +from transformers.utils import auto_docstring, logging + if is_vision_available(): from .image_utils import PILImageResampling @@ -174,103 +176,8 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False): input_data_format: Optional[Union[str, ChannelDimension]] -BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r""" - - Args: - do_resize (`bool`, *optional*, defaults to `self.do_resize`): - Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the - `do_resize` parameter in the `preprocess` method. - size (`dict`, *optional*, defaults to `self.size`): - Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` - method. - default_to_square (`bool`, *optional*, defaults to `self.default_to_square`): - Whether to default to a square image when resizing, if size is an int. - resample (`PILImageResampling`, *optional*, defaults to `self.resample`): - Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be - overridden by the `resample` parameter in the `preprocess` method. - do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): - Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the - `preprocess` method. - crop_size (`Dict[str, int]` *optional*, defaults to `self.crop_size`): - Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` - method. - do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): - Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the - `do_rescale` parameter in the `preprocess` method. - rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): - Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be - overridden by the `rescale_factor` parameter in the `preprocess` method. - do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): - Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` - method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. - image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): - Mean to use if normalizing the image. This is a float or list of floats the length of the number of - channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be - overridden by the `image_mean` parameter in the `preprocess` method. - image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): - Standard deviation to use if normalizing the image. This is a float or list of floats the length of the - number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. - Can be overridden by the `image_std` parameter in the `preprocess` method. - do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): - Whether to convert the image to RGB. - return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`): - Returns stacked tensors if set to `pt, otherwise returns a list of tensors. - data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`): - Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors. 
- input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.""" - -BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r""" - Preprocess an image or batch of images. - Args: - images (`ImageInput`): - Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If - passing in images with pixel values between 0 and 1, set `do_rescale=False`. - do_resize (`bool`, *optional*, defaults to `self.do_resize`): - Whether to resize the image. - size (`Dict[str, int]`, *optional*, defaults to `self.size`): - Describes the maximum input dimensions to the model. - resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`): - Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only - has an effect if `do_resize` is set to `True`. - do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): - Whether to center crop the image. - crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): - Size of the output image after applying `center_crop`. - do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): - Whether to rescale the image. - rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): - Rescale factor to rescale the image by if `do_rescale` is set to `True`. - do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): - Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): - Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. - image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): - Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to - `True`. - do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): - Whether to convert the image to RGB. - return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`): - Returns stacked tensors if set to `pt, otherwise returns a list of tensors. - data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`): - Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors. - input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
- - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.""" - - -@add_start_docstrings( - "Constructs a fast base image processor.", - BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, -) +@auto_docstring class BaseImageProcessorFast(BaseImageProcessor): resample = None image_mean = None @@ -310,7 +217,11 @@ def __init__( if kwarg is not None: setattr(self, key, kwarg) else: - setattr(self, key, getattr(self, key, None)) + setattr(self, key, deepcopy(getattr(self, key, None))) + + # get valid kwargs names + self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys()) + def resize( self, @@ -328,7 +239,7 @@ def resize( Image to resize. size (`SizeDict`): Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. - resample (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. Returns: @@ -367,6 +278,18 @@ def resize( image = Image.fromarray(image) return ms.tensor(np.array(resize(image))).permute(2, 0, 1) + @staticmethod + def compile_friendly_resize( + image: "ms.Tensor", + new_size: tuple[int, int], + interpolation: Optional["F.InterpolationMode"] = None, + antialias: bool = True, + ) -> "ms.Tensor": + """ + A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor. + """ + raise NotImplementedError("This method is not implemented for mindspore") + def rescale( self, image: "ms.Tensor", @@ -516,6 +439,7 @@ def filter_out_unused_kwargs(self, kwargs: dict): def _prepare_images_structure( self, images: ImageInput, + expected_ndims: int = 3, ) -> ImageInput: """ Prepare the images structure for processing. @@ -527,7 +451,7 @@ def _prepare_images_structure( Returns: `ImageInput`: The images with a valid nesting. """ - return make_flat_list_of_images(images) + return make_flat_list_of_images(images, expected_ndims=expected_ndims) def _process_image( self, @@ -549,6 +473,9 @@ def _process_image( # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays image = ms.from_numpy(image).contiguous() + # If the image is 2D, we need to unsqueeze it to add a channel dimension for processing + if image.ndim == 2: + image = image.unsqueeze(0) # Infer the channel dimension format if not provided if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -559,28 +486,47 @@ def _process_image( return image - def _prepare_input_images( + def _prepare_image_like_inputs( self, images: ImageInput, - do_convert_rgb: bool = None, + do_convert_rgb: Optional[bool] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, + expected_ndims: int = 3, ) -> list["ms.Tensor"]: """ - Prepare the input images for processing. + Prepare image-like inputs for processing. + + Args: + images (`ImageInput`): + The image-like inputs to process. + do_convert_rgb (`bool`, *optional*): + Whether to convert the images to RGB. + input_data_format (`str` or `ChannelDimension`, *optional*): + The input data format of the images. + expected_ndims (`int`, *optional*): + The expected number of dimensions for the images. (can be 2 for segmentation maps etc.) + + Returns: + List[`ms.Tensor`]: The processed images. 
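# Hedged example: a 2D segmentation map is accepted when expected_ndims=2 and
# gains a leading channel axis inside `_process_image` (the shape and dtype are
# made up for illustration):
import numpy as np

processor = BaseImageProcessorFast()                  # or any concrete subclass
mask = np.zeros((64, 64), dtype=np.uint8)             # (H, W), no channel axis
out = processor._prepare_image_like_inputs(images=[mask], expected_ndims=2)
# `out` is a list with a single ms.Tensor of shape (1, 64, 64):
# `_process_image` unsqueezes 2D inputs to add the missing channel dimension.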
""" - images = self._prepare_images_structure(images) - process_image_fn = partial( - self._process_image, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, + + # Get structured images (potentially nested) + images = self._prepare_images_structure(images, expected_ndims=expected_ndims) + + process_image_partial = partial( + self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format ) - # todo: yoni - check if we can parallelize this efficiently - processed_images = [] - for image in images: - processed_images.append(process_image_fn(image)) - return processed_images + # Check if we have nested structure, assuming the nesting is consistent + has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple)) + + if has_nested_structure: + processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images] + else: + processed_images = [process_image_partial(img) for img in images] + return processed_images + def _further_process_kwargs( self, size: Optional[SizeDict] = None, @@ -651,22 +597,21 @@ def _validate_preprocess_kwargs( data_format=data_format, ) - @add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS) - def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature: - logger.warning("Please use FastImageProcessor cautiously. It may not have better inference performance!") - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) + def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature: + return self.preprocess(images, *args, **kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature: + # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names) # Set default kwargs from self. This ensures that if a kwarg is not provided # by the user, it gets its default value from the instance, or is set to None. - for kwarg_name in self.valid_kwargs.__annotations__: + for kwarg_name in self._valid_kwargs_names: kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) # Extract parameters that are only used for preparing the input images do_convert_rgb = kwargs.pop("do_convert_rgb") input_data_format = kwargs.pop("input_data_format") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format - ) # Update kwargs that need further processing before being validated kwargs = self._further_process_kwargs(**kwargs) @@ -676,17 +621,40 @@ def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProces # torch resize uses interpolation instead of resample resample = kwargs.pop("resample") + + # Check if resample is an int before checking if it's an instance of PILImageResampling + # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. + # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. 
kwargs["interpolation"] = ( - pil_mindspore_interpolation_mapping[resample] - if isinstance(resample, (PILImageResampling, int)) - else resample + pil_mindspore_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample ) # Pop kwargs that are not needed in _preprocess kwargs.pop("default_to_square") kwargs.pop("data_format") - return self._preprocess(images=images, **kwargs) + return self._preprocess_image_like_inputs( + images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, **kwargs + ) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + *args, + do_convert_rgb: bool, + input_data_format: ChannelDimension, + **kwargs: Unpack[DefaultFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + To be overriden by subclasses when image-like inputs other than images should be processed. + It can be used for segmentation maps, depth maps, etc. + """ + # Prepare input images + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format + ) + return self._preprocess(images, *args, **kwargs) def _preprocess( self, @@ -701,25 +669,22 @@ def _preprocess( do_normalize: bool, image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], + disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], **kwargs, ) -> BatchFeature: # Group images by size for batched resizing - grouped_images, grouped_images_index = group_images_by_shape(images) + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) resized_images_grouped = {} for shape, stacked_images in grouped_images.items(): if do_resize: - stacked_images_updated = [] - for i in range(len(stacked_images)): - stacked_images_updated.append( - self.resize(image=stacked_images[i], size=size, interpolation=interpolation) - ) - resized_images_grouped[shape] = stacked_images_updated + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images resized_images = reorder_images(resized_images_grouped, grouped_images_index) # Group images by size for further processing # Needed in case do_resize is False, or resize returns images with different sizes - grouped_images, grouped_images_index = group_images_by_shape(resized_images) + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) processed_images_grouped = {} for shape, stacked_images in grouped_images.items(): if do_center_crop: @@ -738,9 +703,9 @@ def _preprocess( def to_dict(self): encoder_dict = super().to_dict() encoder_dict.pop("_valid_processor_keys", None) + encoder_dict.pop("_valid_kwargs_names", None) return encoder_dict - class SemanticSegmentationMixin: def post_process_semantic_segmentation(self, outputs, target_sizes: list[tuple] = None): """ diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py index 14246083ca..a3c5d07da1 100644 --- a/mindone/transformers/image_utils.py +++ b/mindone/transformers/image_utils.py @@ -36,7 +36,14 @@ requires_backends, to_numpy, ) - +from .utils.constants import ( # noqa: F401 + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, +) if is_vision_available(): import PIL.Image import PIL.ImageOps From 
991a783979fc5ec48c9771211e33117907b0cb60 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 11:58:59 +0800 Subject: [PATCH 34/94] update integration sdpa_attention.py --- .../transformers/integrations/sdpa_attention.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mindone/transformers/integrations/sdpa_attention.py b/mindone/transformers/integrations/sdpa_attention.py index 5f61df46f8..0c1597eac6 100644 --- a/mindone/transformers/integrations/sdpa_attention.py +++ b/mindone/transformers/integrations/sdpa_attention.py @@ -21,6 +21,10 @@ def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim)) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def use_gqa_in_sdpa(attention_mask: Optional[ms.Tensor], key: ms.Tensor) -> bool: + # GQA is not supported yet. + return False + def sdpa_attention_forward( module: nn.Cell, @@ -38,10 +42,13 @@ def sdpa_attention_forward( "`sdpa` attention does not support `output_attentions=True` or `head_mask`." " Please set your attention to `eager` if you want any of these features." ) - + sdpa_kwargs = {} if hasattr(module, "num_key_value_groups"): - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) + if not use_gqa_in_sdpa(attention_mask, key): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + else: + sdpa_kwargs = {"enable_gqa": True} if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -73,6 +80,7 @@ def sdpa_attention_forward( atten_mask=attention_mask, scale=scaling, keep_prob=1 - dropout, + **sdpa_kwargs, )[0] attn_output = mint.transpose(attn_output, 1, 2).contiguous() From ccb08973f6f21cc7de49e89f620445ed3b50c3b0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 12:14:01 +0800 Subject: [PATCH 35/94] update mask_utils.py --- mindone/transformers/masking_utils.py | 564 +++++++++++++++++++++++++- 1 file changed, 549 insertions(+), 15 deletions(-) diff --git a/mindone/transformers/masking_utils.py b/mindone/transformers/masking_utils.py index 25fb7c64eb..7f8dfe3d93 100644 --- a/mindone/transformers/masking_utils.py +++ b/mindone/transformers/masking_utils.py @@ -15,6 +15,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools from typing import Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig @@ -75,6 +76,16 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return inner_mask +def chunked_overlay(chunk_size: int) -> Callable: + """ + This is an overlay depicting a chuned attention pattern. Add it on top of a causal mask for a proper chunked + attention mask. 
+ """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return kv_idx // chunk_size == q_idx // chunk_size + + return inner_mask def sliding_window_causal_mask_function(sliding_window: int) -> Callable: """ @@ -82,6 +93,48 @@ def sliding_window_causal_mask_function(sliding_window: int) -> Callable: """ return and_masks(sliding_window_overlay(sliding_window), causal_mask_function) +def chunked_causal_mask_function(chunk_size: int) -> Callable: + """ + This return the mask_function function to create a chunked attention mask. + """ + return and_masks(chunked_overlay(chunk_size), causal_mask_function) + + +def padding_mask_function(padding_mask: ms.Tensor) -> Callable: + """ + This return the mask_function function corresponding to a 2D padding mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because + # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not + # vectorizable on accelerator devices + return padding_mask[batch_idx, kv_idx] + + return inner_mask + + +def packed_sequence_mask_function(packed_sequence_mask: ms.Tensor) -> Callable: + """ + This return the mask_function function corresponding to a 2D packed sequence mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx] + + return inner_mask + + +def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset) + + return inner_mask def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """ @@ -131,6 +184,137 @@ def prepare_padding_mask( local_padding_mask = local_padding_mask[:, mask_indices] return local_padding_mask +def sdpa_mask_recent_torch( + batch_size: int, + cache_position: ms.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[ms.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + **kwargs, +) -> Optional[ms.Tensor]: + """ + Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that + the element should take part in the attention computation, and False that it should not. + This function can only be used with torch>=2.5, as the context manager is otherwise not available. + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`ms.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. 
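Illustration (standalone Python/NumPy, not part of the patch): the mask helpers above are plain index predicates that compose with `and_masks`/`or_masks`. The sketch below re-implements the same idea locally (the names mirror the patch but are re-definitions for this example only) and materializes a chunked causal mask; the printed grid matches the chunked-attention example shown in the docstring that follows.

```python
import numpy as np

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # attend only to the current and previous positions
    return kv_idx <= q_idx

def chunked_overlay(chunk_size):
    # attend only to positions inside the same chunk
    return lambda batch_idx, head_idx, q_idx, kv_idx: kv_idx // chunk_size == q_idx // chunk_size

def and_masks(*mask_functions):
    # intersection of several index predicates
    return lambda b, h, q, kv: all(fn(b, h, q, kv) for fn in mask_functions)

chunked_causal = and_masks(chunked_overlay(3), causal_mask_function)
grid = np.array([[chunked_causal(0, 0, q, kv) for kv in range(5)] for q in range(5)])
print(grid.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 1]]
```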
+ attention_mask (`ms.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + local_size (`int`, optional): + The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` + to try to skip mask creation if possible. + allow_is_causal_skip (`bool`, optional): + Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in + `torch.sdpa` instead. Default to `True`. + allow_torch_fix (`bool`, optional): + Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older + versions. We need an arg to skip it when using eager. By default `True`. + + + ## Creating a simple causal mask: + + To create the following causal mask: + + 0 ■ ⬚ ⬚ ⬚ ⬚ + 1 ■ ■ ⬚ ⬚ ⬚ + 2 ■ ■ ■ ⬚ ⬚ + 3 ■ ■ ■ ■ ⬚ + 4 ■ ■ ■ ■ ■ + + You can do + + ```python + >>> create_4d_causal_mask(batch_size=1, cache_position=mint.arange(5), kv_length=5) + >>> tensor([[[[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [ True, True, True, True, False], + [ True, True, True, True, True]]]]) + ``` + + ## Creating a sliding window mask: + + To create the following sliding window mask (`sliding_window=3`): + + 0 ■ ⬚ ⬚ ⬚ ⬚ + 1 ■ ■ ⬚ ⬚ ⬚ + 2 ■ ■ ■ ⬚ ⬚ + 3 ⬚ ■ ■ ■ ⬚ + 4 ⬚ ⬚ ■ ■ ■ + + You can do + + ```python + >>> create_4d_causal_mask(batch_size=1, cache_position=mint.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3)) + >>> tensor([[[[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [False, True, True, True, False], + [False, False, True, True, True]]]]) + ``` + + ## Creating a chunked attention mask + + To create the following chunked attention mask (`chunk_size=3`): + + 0 ■ ⬚ ⬚ ⬚ ⬚ + 1 ■ ■ ⬚ ⬚ ⬚ + 2 ■ ■ ■ ⬚ ⬚ + 3 ⬚ ⬚ ⬚ ■ ⬚ + 4 ⬚ ⬚ ⬚ ■ ■ + + You can do + + ```python + >>> create_4d_causal_mask(batch_size=1, cache_position=mint.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3)) + >>> tensor([[[[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [False, False, False, True, False], + [False, False, False, True, True]]]]) + ``` + + """ + q_length = cache_position.shape[0] + # Potentially pad the 2D mask, and slice it correctly + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False) + + # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument + if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size): + return None + + # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. torch.compile friendly) + kv_arange = mint.arange(kv_length, device=cache_position.device) + kv_arange += kv_offset + + # Potentially add the padding 2D mask + if padding_mask is not None: + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + + batch_arange = mint.arange(batch_size, device=cache_position.device) + head_arange = mint.arange(1, device=cache_position.device) + # This creates the 4D mask easily. 
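For intuition, the same 4D boolean mask that the vmap-based construction above produces can also be obtained with plain broadcasting over `cache_position` and `kv_arange`. The NumPy sketch below is illustration only (not part of the patch); the 2D `padding_mask` and the concrete sizes are hypothetical.

```python
import numpy as np

q_len, kv_len, kv_offset = 3, 5, 0
cache_position = np.arange(kv_len - q_len, kv_len) + kv_offset   # absolute positions of the new queries
kv_arange = np.arange(kv_len) + kv_offset                        # absolute positions of the keys/values
padding_mask = np.array([[1, 1, 1, 1, 1],
                         [0, 1, 1, 1, 1]], dtype=bool)           # hypothetical 2D padding mask (batch=2)

# (1, 1, q_len, kv_len) causal pattern AND (batch, 1, 1, kv_len) padding -> (batch, 1, q_len, kv_len)
causal = kv_arange[None, None, None, :] <= cache_position[None, None, :, None]
mask_4d = causal & padding_mask[:, None, None, :]
print(mask_4d.shape)              # (2, 1, 3, 5)
print(mask_4d[0, 0].astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```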
Note that we need this context manager as vmap cannot handle slicing a tensor from + # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it + # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices + # with TransformGetItemToIndex(): + causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + + return causal_mask + def sdpa_mask_older_torch( batch_size: int, @@ -156,7 +340,7 @@ def sdpa_mask_older_torch( Args: batch_size (`int`): The batch size of the input sequence. - cache_position (`torch.Tensor`): + cache_position (`ms.Tensor`): A tensor of shape (query_length,) indicating the current indices of the input sequence elements. kv_length (`int`): The size that the key and value states will have during the attention computation. @@ -164,7 +348,7 @@ def sdpa_mask_older_torch( An optional offset to indicate at which first position the key and values states will refer to. mask_function (`Callable`): The mask factory function describing the mask pattern. - attention_mask (`torch.Tensor`, optional): + attention_mask (`ms.Tensor`, optional): The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) local_size (`int`, optional): The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` @@ -184,7 +368,7 @@ def sdpa_mask_older_torch( if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size): return None - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` + # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` # but without data-dependent slicing (i.e. torch.compile friendly) kv_arange = mint.arange(kv_length) kv_arange += kv_offset @@ -246,7 +430,7 @@ def _ignore_causal_mask_sdpa( # We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions # (especially mask_function indexing a tensor, such as the padding mask function) -sdpa_mask = sdpa_mask_older_torch +sdpa_mask = sdpa_mask_older_torch # TODO: use sdpa_mask_recent_torch orsdpa_mask_older_torch? def eager_mask( @@ -267,7 +451,7 @@ def eager_mask( Args: batch_size (`int`): The batch size of the input sequence. - cache_position (`torch.Tensor`): + cache_position (`ms.Tensor`): A tensor of shape (query_length,) indicating the current indices of the input sequence elements. kv_length (`int`): The size that the key and value states will have during the attention computation. @@ -275,7 +459,7 @@ def eager_mask( An optional offset to indicate at which first position the key and values states will refer to. mask_function (`Callable`): The mask factory function describing the mask pattern. - attention_mask (`torch.Tensor`, optional): + attention_mask (`ms.Tensor`, optional): The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) dtype (`torch.dtype`, optional): The dtype to use for the mask. By default, `torch.float32`. 
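As a quick illustration of the eager path described above (a minimal sketch, not part of the patch): the boolean mask is converted into an additive bias, 0 where attention is allowed and the most negative representable value of the compute dtype where it is not, so it can simply be added to the raw attention scores before the softmax.

```python
import numpy as np

bool_mask = np.array([[True, False, False],
                      [True, True,  False],
                      [True, True,  True]])

dtype = np.float32
min_value = np.finfo(dtype).min          # stands in for the framework's finfo(dtype).min

additive_bias = np.where(bool_mask, dtype(0.0), dtype(min_value))
print(additive_bias.dtype)               # float32
print(additive_bias[0])                  # [0.0, ~-3.4e38, ~-3.4e38]
```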
@@ -364,6 +548,32 @@ class AttentionMaskInterface(GeneralInterface): # Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface() +def find_packed_sequence_indices(position_ids: ms.Tensor) -> ms.Tensor: + """ + Find the indices of the sequence to which each new query token in the sequence belongs when using packed + tensor format (i.e. several sequences packed in the same batch dimension). + + Args: + position_ids (`ms.Tensor`) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. + + Returns: + A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we + pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]]. + """ + # What separate different sequences is when 2 consecutive positions_ids are separated by more than 1. So + # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result + # gives exactly the sequence indices + # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence + # cannot be part of the end of the first batch dim and the start of the 2nd one for example + first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1 + position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1) + packed_sequence_mask = (position_diff != 1).cumsum(-1) + + # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0` + # but it causes issues with export + return packed_sequence_mask + def _preprocess_mask_arguments( config: PretrainedConfig, @@ -371,6 +581,7 @@ def _preprocess_mask_arguments( attention_mask: Optional[Union[ms.Tensor, BlockMask]], cache_position: ms.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[ms.Tensor], layer_idx: Optional[int], ) -> tuple[bool, Optional[Union[ms.Tensor, BlockMask]], int, int]: """ @@ -390,6 +601,8 @@ def _preprocess_mask_arguments( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`ms.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. layer_idx (`int`, optional): If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value length and offset. Indeed, for hybrid caches, different layers may return different lengths. @@ -399,6 +612,9 @@ def _preprocess_mask_arguments( Whether we should early exit mask creation, and return the mask as-is. attention_mask (`ms.Tensor` or `BlockMask` or `None`): The attention mask to either return immediately, or to use in downstream mask creation. + packed_sequence_mask (`ms.Tensor`, optional): + In case we detected packed sequence format, this is a tensor where each similar integer indicates that + the tokens belong to the same sequence. kv_length (`int`): The size that the key and value states will have during the attention computation. 
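Referring back to `find_packed_sequence_indices` above, the NumPy sketch below (illustration only, not part of the patch) walks through the diff-and-cumsum trick on the example from its docstring: positions jump or restart whenever a new packed sequence begins, so marking every step whose delta differs from 1 and taking a cumulative sum yields a per-token sequence id.

```python
import numpy as np

position_ids = np.array([[0, 1, 0, 1, 2, 0]])     # three sequences (lengths 2, 3, 1) packed together
first_dummy = position_ids[:, :1] - 1             # keeps the very first delta equal to 1
deltas = np.diff(position_ids, prepend=first_dummy, axis=-1)
packed_sequence_mask = (deltas != 1).cumsum(-1)
print(packed_sequence_mask)                       # [[0 0 1 1 1 2]]
```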
kv_offset (`int`): @@ -414,7 +630,7 @@ def _preprocess_mask_arguments( # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: - return True, None, None, None + return True, None, None, None, None # Move the mask to correct device, and potentially switch dtype for efficiency if attention_mask is not None and attention_mask.ndim == 2: @@ -426,8 +642,17 @@ def _preprocess_mask_arguments( # Otherwise, the sizes are simply the input sizes else: kv_length, kv_offset = input_embeds.shape[1], 0 + # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None, + # and we don't have past_key_values, i.e. generally a training setup) + packed_sequence_mask = None + if position_ids is not None and attention_mask is None and past_key_values is None: + batch_size = input_embeds.shape[0] + # The position ids are sometimes just unsqueezed, without being expanded + if batch_size != position_ids.shape[0]: + position_ids = position_ids.expand(batch_size, -1) + packed_sequence_mask = find_packed_sequence_indices(position_ids) - return False, attention_mask, kv_length, kv_offset + return False, attention_mask, packed_sequence_mask, kv_length, kv_offset def create_causal_mask( @@ -436,6 +661,7 @@ def create_causal_mask( attention_mask: Optional[ms.Tensor], cache_position: ms.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[ms.Tensor] = None, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[ms.Tensor, BlockMask]]: @@ -457,6 +683,8 @@ def create_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`ms.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. or_mask_function (`Callable`, optional): An optional mask function to combine with the causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. 
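To make the `or_mask_function` hook described above concrete, here is a standalone Python/NumPy sketch (not part of the patch) in which a block of image tokens attends bidirectionally while text tokens stay causal; `image_token_positions` and both helper functions are hypothetical local definitions for this example only.

```python
import numpy as np

image_token_positions = np.array([[0, 1, 1, 1, 0]], dtype=bool)   # tokens 1..3 are image tokens

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def bidirectional_image_overlay(batch_idx, head_idx, q_idx, kv_idx):
    # image tokens may attend to each other in both directions
    return bool(image_token_positions[batch_idx, q_idx] and image_token_positions[batch_idx, kv_idx])

def or_masks(*mask_functions):
    # union of several index predicates
    return lambda b, h, q, kv: any(fn(b, h, q, kv) for fn in mask_functions)

mask_fn = or_masks(causal, bidirectional_image_overlay)
grid = np.array([[mask_fn(0, 0, q, kv) for kv in range(5)] for q in range(5)])
print(grid.astype(int))
# [[1 0 0 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```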
@@ -470,8 +698,8 @@ def create_causal_mask( else: layer_idx = 0 - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx ) if early_exit: return attention_mask @@ -484,6 +712,10 @@ def create_causal_mask( # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + # If we detected packing format + if packed_sequence_mask is not None: + mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) + allow_is_causal_skip = False # Allow slight deviations from causal mask if or_mask_function is not None: mask_factory_function = or_masks(mask_factory_function, or_mask_function) @@ -513,6 +745,7 @@ def create_sliding_window_causal_mask( attention_mask: Optional[ms.Tensor], cache_position: ms.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[ms.Tensor] = None, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[ms.Tensor, BlockMask]]: @@ -535,6 +768,8 @@ def create_sliding_window_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`ms.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. or_mask_function (`Callable`, optional): An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. 
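For reference, the sliding-window pattern built here is just the intersection of the causal rule with a `kv_idx > q_idx - sliding_window` band. The NumPy sketch below (not part of the patch) reproduces the `sliding_window=3` figure shown in the docstring earlier using broadcasting instead of index predicates.

```python
import numpy as np

sliding_window = 3
positions = np.arange(5)

causal = positions[None, :] <= positions[:, None]                      # kv_idx <= q_idx
in_window = positions[None, :] > positions[:, None] - sliding_window   # kv_idx > q_idx - window
print((causal & in_window).astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 1 1 1 0]
#  [0 0 1 1 1]]
```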
@@ -548,8 +783,8 @@ def create_sliding_window_causal_mask( else: layer_idx = 0 - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx ) if early_exit: return attention_mask @@ -565,10 +800,18 @@ def create_sliding_window_causal_mask( # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True - + # If we detected packing format + if packed_sequence_mask is not None: + mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) + allow_is_causal_skip = False # Allow slight deviations from sliding causal mask - if or_mask_function is not None or and_mask_function is not None: - raise NotImplementedError("`or_mask_function` or `and_mask_function` arguments are not supported yet.") + if or_mask_function is not None: + mask_factory_function = or_masks(mask_factory_function, or_mask_function) + allow_is_causal_skip = False + if and_mask_function is not None: + mask_factory_function = and_masks(mask_factory_function, and_mask_function) + allow_is_causal_skip = False + # We now create the mask causal_mask = mask_interface( @@ -585,8 +828,299 @@ def create_sliding_window_causal_mask( ) return causal_mask +def create_chunked_causal_mask( + config: PretrainedConfig, + input_embeds: ms.Tensor, + attention_mask: Optional[ms.Tensor], + cache_position: ms.Tensor, + past_key_values: Optional[Cache], + position_ids: Optional[ms.Tensor] = None, + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, +) -> Optional[Union[ms.Tensor, BlockMask]]: + """ + Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type + of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this + function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the + `modeling_xxx.py` files). + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`ms.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`ms.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`ms.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + position_ids (`ms.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. 
+ and_mask_function (`Callable`, optional): + An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. + """ + # If we have an HybridCache structure, here we want to create the mask for the sliding layers + if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: + layer_idx = past_key_values.is_sliding.index(True) + else: + layer_idx = 0 + + early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx + ) + if early_exit: + return attention_mask + + chunk_size = getattr(config, "attention_chunk_size", None) + if chunk_size is None: + raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set") + + # Raise if using chunked attention on context too large with FA2 + if config._attn_implementation == "flash_attention_2" and kv_length + kv_offset > chunk_size: + raise ValueError( + "Flash attention 2 cannot handle chunked attention, and the key-value length is larger than the chunk size so the " + "chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model" + ) + + batch_size, dtype = input_embeds.shape[0], input_embeds.dtype + mask_factory_function = chunked_causal_mask_function(chunk_size) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + + # Do not allow skip if we are compiling (this is to match BC) + # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it + allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + + # If we detected packing format + if packed_sequence_mask is not None: + mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) + allow_is_causal_skip = False + + # Allow slight deviations from chunked causal mask + if or_mask_function is not None: + mask_factory_function = or_masks(mask_factory_function, or_mask_function) + allow_is_causal_skip = False + if and_mask_function is not None: + mask_factory_function = and_masks(mask_factory_function, and_mask_function) + allow_is_causal_skip = False + + # We now create the mask + causal_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_factory_function, + attention_mask=attention_mask, + allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa + local_size=chunk_size, # Additional kwarg for sdpa + dtype=dtype, # Additional kwarg for eager + config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + ) + return causal_mask + LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = { "full_attention": create_causal_mask, "sliding_attention": create_sliding_window_causal_mask, + "chunked_attention": create_chunked_causal_mask, } + +def create_masks_for_generate( + config: PretrainedConfig, + input_embeds: ms.Tensor, + attention_mask: Optional[ms.Tensor], + cache_position: ms.Tensor, + past_key_values: Optional[Cache], + position_ids: Optional[ms.Tensor] = None, + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, + **kwargs, +): + """ + This function mimics how we create the masks 
in the `modeling_xxx.py` files, and is used in `generate` in order + to easily create the masks in advance, when we compile the forwards with Static caches. + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`ms.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`ms.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`ms.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + position_ids (`ms.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the other mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. + and_mask_function (`Callable`, optional): + An optional mask function to combine with the other mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. + """ + # The attribute reside in the text config for composite models + effective_config = config.get_text_config() + # Prepare the mask args + mask_kwargs = { + "config": effective_config, + "input_embeds": input_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + "or_mask_function": or_mask_function, + "and_mask_function": and_mask_function, + } + + # If the attribute exist, we need several masks + if hasattr(effective_config, "layer_types"): + causal_masks = {} + for layer_pattern in set(effective_config.layer_types): + causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs) + return causal_masks + # In this case, all layers are sliding + elif getattr(effective_config, "sliding_window", None) is not None: + return create_sliding_window_causal_mask(**mask_kwargs) + # In this case, all layers are chunked + elif getattr(effective_config, "attention_chunk_size", None) is not None: + return create_chunked_causal_mask(**mask_kwargs) + # All layers use standard causal attention + return create_causal_mask(**mask_kwargs) + + +# Below are utilities to pretty-print the different masks +# Print the matrix with words as row labels +GREEN = "\033[92m" +YELLOW = "\033[93m" +RESET = "\033[0m" +BLACK_SQUARE = "■" +WHITE_SQUARE = "⬚" +GREY_SQUARE = "∙" +LOW_TRIANGLE = "⬕" +UPPER_TRIANGLE = "⬔" + + +def get_style(style): + if style == "majong": + BLACK_SQUARE = "🀞" # Full block (represents "on" or active) + BLACK_SQUARE = "🀙" # Full block (represents "on" or active) + WHITE_SQUARE = "🀆" # "▒" # Light shade (represents "off" or inactive) + LOW_TRIANGLE = "🀛" # Lower left triangle (stylized indication) + UPPER_TRIANGLE = "🀛" # Upper left triangle (stylized indication) + else: + BLACK_SQUARE = "█" # Full block (represents "on" or active) + WHITE_SQUARE = "░" # "▒" # Light shade (represents "off" or inactive) + LOW_TRIANGLE = "▙" # Lower left triangle (stylized indication)) + UPPER_TRIANGLE = "▜" # 
Upper left triangle (stylized indication) + + return BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE + + +# LOW_TRIANGLE = UPPER_TRIANGLE = "⟍" # Upper right triangle (stylized indication) + +YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}" +GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}" + + +def tensor_to_mask_visual(original_tensor: ms.Tensor, grid_size=(20, 40), style="majong") -> str: + BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style) + h, w = original_tensor.shape + max_h, max_w = grid_size + if not (h < max_h and w < max_w): + # Preserve aspect ratio within max grid size + aspect_ratio = 2 * w / h + if aspect_ratio > 1: + w = max_w + h = min(max_h, max(1, round(max_w / aspect_ratio))) + else: + h = max_h + w = max(1, round(max_h * aspect_ratio)) + + # Step 1: Rescale tensor by average pooling + tensor = original_tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + tensor = F.adaptive_avg_pool2d(tensor, output_size=(h, w))[0, 0] # Remove extra dims + else: + tensor = original_tensor + + # Step 3: Build the string representation + result = [] + for i in range(h): + row = "" + for j in range(w): + if tensor[i, j] == 1: + row += BLACK_SQUARE + elif tensor[i, j] == 0: + row += WHITE_SQUARE + else: + if j > 0: + if tensor[i, j - 1] == 1: + row += LOW_TRIANGLE + elif tensor[i, j - 1] == 0: + row += UPPER_TRIANGLE + else: + row += BLACK_SQUARE if tensor[i, j] == 1 else WHITE_SQUARE + else: + row += ( + BLACK_SQUARE + if tensor[i, j] == 1 + else ( + WHITE_SQUARE + if tensor[i, j] == 0 + else (UPPER_TRIANGLE if tensor[i, j + 1] == 1 else LOW_TRIANGLE) + ) + ) + result.append(row) + + return "\n".join(result) + + +class AttentionMask(ms.Tensor): + def __new__(cls, data, style=None): + # Create a new instance of AttentionMask as a Tensor + cls.style = style + return ms.Tensor._make_subclass(cls, data, require_grad=False) + + def __init__(self, data): + # You can initialize any additional metadata here if needed + pass + + def to_string(self, grid_size=(20, 40), limit=4): + """Returns a string representation of the block mask.""" + dense_mask = self + *batch_dims, num_rows, num_cols = dense_mask.shape + total_vis = [] + + for idx, batch_idx in enumerate(itertools.product(*[range(i) for i in batch_dims])): + if idx == limit: + total_vis.append("...") + total_vis.append("To print out more, set AttentionMask.to_string(limit=N)") + total_vis.append("You can also index (AttentionMask[batch, head]) to choose a specific batch or head") + break + block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style) + total_vis.append(block_vis) + + total_vis.append(f"ms.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})") + return "\n".join(total_vis) + + def __repr__(self): + return self.to_string() + + def __str__(self): + return self.to_string() + + @classmethod + def from_tensor(cls, tensor: ms.Tensor, style: Optional[str] = None) -> "AttentionMask": + res = cls(tensor) + res.style = style + return res From 00f2ba3b27132b9bcfa293c86fe2c17526c6f9e0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 13:13:40 +0800 Subject: [PATCH 36/94] update modeling_flash_attention_utils.py --- .../modeling_flash_attention_utils.py | 440 +++++++++++++++++- 1 file changed, 423 insertions(+), 17 deletions(-) diff --git a/mindone/transformers/modeling_flash_attention_utils.py b/mindone/transformers/modeling_flash_attention_utils.py index 4b4f5bd38d..b509236cd3 100644 --- 
a/mindone/transformers/modeling_flash_attention_utils.py +++ b/mindone/transformers/modeling_flash_attention_utils.py @@ -14,13 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect +import warnings +import os from typing import Optional, TypedDict from transformers.utils import logging import mindspore as ms - +from mindspore import mint +import mindspore.mint.functional as F logger = logging.get_logger(__name__) @@ -33,22 +36,425 @@ def is_flash_attn_available(): return False -class FlashAttentionKwargs(TypedDict, total=False): +def _index_first_axis(tensor: ms.Tensor, indices: ms.Tensor) -> ms.Tensor: + reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:]) + return reshaped[indices] + + +def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + FA3-compatible unpad_input function. + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. """ - Keyword arguments for Flash Attention with Compile. + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=ms.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=ms.int32) + indices = mint.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(mint.cumsum(seqlens_in_batch, dim=0, dtype=ms.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + - Attributes: - cu_seq_lens_q (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for query state. - cu_seq_lens_k (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for key state. - max_length_q (`int`, *optional*): - Maximum sequence length for query state. - max_length_k (`int`, *optional*): - Maximum sequence length for key state. +def _fa3_pad_input(hidden_states, indices, batch, seqlen): """ + FA3-compatible pad_input function. + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = mint.zeros((batch * seqlen, *dim), dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: ms.Tensor) -> tuple[ms.Tensor, ms.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. 
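The padding helpers above boil down to a gather/scatter on a flattened batch: collect the flat indices of valid tokens, record cumulative sequence lengths, gather the ragged tokens, and later scatter them back into a zero-initialised padded tensor. The NumPy sketch below (illustration only, not part of the patch) walks through that round trip on toy data.

```python
import numpy as np

hidden = np.arange(2 * 4 * 3).reshape(2, 4, 3)        # (batch, seqlen, hidden_dim)
attention_mask = np.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]])

seqlens = attention_mask.sum(-1)                      # [3 2] valid tokens per sample
indices = np.nonzero(attention_mask.flatten())[0]     # flat positions of the valid tokens
cu_seqlens = np.pad(np.cumsum(seqlens), (1, 0))       # [0 3 5] cumulative boundaries
unpadded = hidden.reshape(-1, 3)[indices]             # (5, 3): ragged tokens only

# pad step: scatter the ragged tokens back into a zero-initialised padded tensor
repadded = np.zeros((2 * 4, 3), dtype=hidden.dtype)
repadded[indices] = unpadded
repadded = repadded.reshape(2, 4, 3)

print(cu_seqlens, unpadded.shape)                     # [0 3 5] (5, 3)
valid = attention_mask.astype(bool)
assert (repadded[valid] == hidden[valid]).all()       # round trip preserves all valid tokens
```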
+ Arguments: + attention_mask (`ms.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + Return: + indices (`ms.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`ms.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=ms.int32) + indices = mint.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(mint.cumsum(seqlens_in_batch, dim=0, dtype=ms.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: ms.Tensor, + key_layer: ms.Tensor, + value_layer: ms.Tensor, + attention_mask: ms.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + Arguments: + query_layer (`ms.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`ms.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`ms.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`ms.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + Return: + query_layer (`ms.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`ms.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`ms.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`ms.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = mint.arange( + batch_size + 1, dtype=ms.int32, + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def _prepare_from_posids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + NOTE: ideally cumulative lengths should be prepared at the data collator stage + Arguments: + query (`ms.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`ms.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`ms.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`ms.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + Return: + query (`ms.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`ms.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`ms.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`ms.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ """ + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + + position_ids = position_ids.flatten() + indices_q = mint.arange(position_ids.size(0), dtype=ms.int32) + + cu_seq_lens = mint.cat( + ( + indices_q[position_ids == 0], + ms.Tensor(position_ids.size(), dtype=ms.int32), + ) + ) + # NOTE: With torch compile, this will cause a graph break if you don't set + # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call + # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. + # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` + # requires `max_length_q`, `max_length_k` to be passed as `int` and not `ms.Tensor`. + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length = cu_seq_lens.diff().max().item() + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): + warnings.warn( + "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids", + FutureWarning, + ) + return _prepare_from_posids(query, key, value, position_ids) + + +def fa_peft_integration_check(q, k, v, target_dtype: Optional[ms.Type] = None): + if target_dtype and q.dtype == ms.float32: + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + return q, k, v + +#TODO: fix this flash attention 2 and 3 +def _lazy_imports(impl: Optional[str]): + # returns funcs and pad/unpad based on impl + is_fa2 = is_flash_attn_2_available() or is_torch_npu_available() + is_fa3 = is_flash_attn_3_available() + if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False + + except ImportError as e: + if not globals().get("use_remote_fa2", None): + use_remote_fa2 = ( + input( + "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? " + ) + .strip() + .lower() + ) + globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} + if globals()["use_remote_fa2"]: + if not is_kernels_available(): + raise ImportError("You need to install kernels: `pip install kernels`") + from kernels import get_kernel + + impl = get_kernel("kernels-community/flash-attn") + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) + + else: + raise ImportError( + "Failed to import flash attention 2, please install it or use another implementation." 
+ ) from e + if impl == "flash_attention_3" or (impl is None and is_fa3): + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True + else: + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) + + +_flash_supports_window = None + + +def is_flash_attn_available(): + # return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() + return True + + +def flash_attn_supports_top_left_mask(): + if is_flash_attn_3_available(): + return False + if is_flash_attn_2_available(): + return not is_flash_attn_greater_or_equal_2_10() + + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + + return is_npu_fa2_top_left_aligned_causal_mask() + + +class FlashAttentionKwargs(TypedDict, total=False): + cumulative_seqlens_q: Optional[ms.Tensor] + cumulative_seqlens_k: Optional[ms.Tensor] + + +def _flash_attention_forward( + query_states: ms.Tensor, + key_states: ms.Tensor, + value_states: ms.Tensor, + attention_mask: Optional[ms.Tensor], + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[ms.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + cu_seq_lens_q: Optional[ms.Tensor] = None, + cu_seq_lens_k: Optional[ms.Tensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[ms.Type] = None, + implementation: Optional[str] = None, + **kwargs, +): + if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): + flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) + globals()["_flash_fn"] = flash_fn + globals()["_flash_varlen_fn"] = flash_varlen_fn + globals()["_pad_fn"] = pad_fn + globals()["_unpad_fn"] = unpad_fn + globals()["_is_fa3"] = is_fa3 + flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters + globals()["_flash_supports_window"] = flash_supports_window + else: + flash_fn = globals()["_flash_fn"] + flash_varlen_fn = globals()["_flash_varlen_fn"] + pad_fn = globals()["_pad_fn"] + unpad_fn = globals()["_unpad_fn"] + is_fa3 = globals()["_is_fa3"] + flash_supports_window = globals()["_flash_supports_window"] + + causal = is_causal and not (use_top_left_mask and query_length == 1) + use_sw = ( + (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} + if not is_fa3: + flash_kwargs["dropout_p"] = dropout + if is_flash_attn_greater_or_equal("2.4.1"): + det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = det + if softcap is not None: + flash_kwargs["softcap"] = softcap + + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype + ) + use_mask = position_ids is not None or all( + k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k] + ) + if attention_mask is not None: + q, k, 
v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn + ) + + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_q.to(ms.int32), + cu_seqlens_k=cu_k.to(ms.int32), + max_seqlen_q=mq, + max_seqlen_k=mk, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) + elif use_mask: + if cu_seq_lens_q is None or cu_seq_lens_k is None: + if position_ids is None: + raise ValueError( + "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." + ) + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( + query_states, key_states, value_states, position_ids + ) + else: + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + mq, mk = max_length_q, max_length_k + cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_q.to(ms.int32), + cu_seqlens_k=cu_k.to(ms.int32), + max_seqlen_q=mq, + max_seqlen_k=mk, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + if isinstance(out, tuple): + out = out[0] + out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) + else: + out = flash_fn( + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return out[0] if isinstance(out, tuple) else out + - cu_seq_lens_q: Optional[ms.Tensor] - cu_seq_lens_k: Optional[ms.Tensor] - max_length_q: Optional[int] - max_length_k: Optional[int] From d0b34fb96c5a12d62591731b23c07f7bf1d5dcfa Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 13:26:46 +0800 Subject: [PATCH 37/94] update modeling_outputs.py --- mindone/transformers/modeling_outputs.py | 248 ++++++++++------------- 1 file changed, 106 insertions(+), 142 deletions(-) diff --git a/mindone/transformers/modeling_outputs.py b/mindone/transformers/modeling_outputs.py index 252f1a04be..8b4be73ec8 100644 --- a/mindone/transformers/modeling_outputs.py +++ b/mindone/transformers/modeling_outputs.py @@ -20,6 +20,7 @@ from transformers.utils import ModelOutput +from .cache_utils import Cache, EncoderDecoderCache import mindspore as ms @@ -44,7 +45,7 @@ class BaseModelOutput(ModelOutput): heads. """ - last_hidden_state: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -64,7 +65,7 @@ class BaseModelOutputWithNoAttention(ModelOutput): Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. """ - last_hidden_state: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None @@ -94,8 +95,8 @@ class BaseModelOutputWithPooling(ModelOutput): heads. 
""" - last_hidden_state: ms.Tensor = None - pooler_output: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None + pooler_output: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -117,8 +118,8 @@ class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. """ - last_hidden_state: ms.Tensor = None - pooler_output: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None + pooler_output: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None @@ -133,11 +134,8 @@ class BaseModelOutputWithPast(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` @@ -155,8 +153,8 @@ class BaseModelOutputWithPast(ModelOutput): heads. """ - last_hidden_state: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + last_hidden_state: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -189,7 +187,7 @@ class BaseModelOutputWithCrossAttentions(ModelOutput): weighted average in the cross-attention heads. """ - last_hidden_state: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -226,21 +224,18 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. """ - last_hidden_state: ms.Tensor = None - pooler_output: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None + pooler_output: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + past_key_values: Optional[Cache] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -256,11 +251,8 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` @@ -302,9 +294,8 @@ class MoECausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -331,12 +322,12 @@ class MoECausalLMOutputWithPast(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None - z_loss: ms.Tensor = None - aux_loss: ms.Tensor = None + z_loss: Optional[ms.Tensor] = None + aux_loss: Optional[ms.Tensor] = None router_logits: Optional[Tuple[ms.Tensor]] = None @@ -367,7 +358,7 @@ class MoEModelOutput(ModelOutput): loss and the z_loss for Mixture of Experts models. 
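The extra `z_loss` / `aux_loss` / `router_logits` fields on the MoE outputs above exist so the router regularizers can be added to the language-modeling loss during training. A toy sketch with stand-in values; the coefficient names and the numbers are assumptions, and in practice the tensors come from a model call with `output_router_logits=True`:

```python
import mindspore as ms

# Stand-ins for outputs.loss, outputs.z_loss, outputs.aux_loss
lm_loss = ms.Tensor(2.31, ms.float32)
z_loss = ms.Tensor(0.05, ms.float32)
aux_loss = ms.Tensor(0.12, ms.float32)

router_z_coef, router_aux_coef = 1e-3, 1e-2       # assumed hyperparameters
total_loss = lm_loss + router_z_coef * z_loss + router_aux_coef * aux_loss
print(float(total_loss))
```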
""" - last_hidden_state: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None router_probs: Optional[Tuple[ms.Tensor]] = None @@ -381,11 +372,8 @@ class MoeModelOutputWithPast(ModelOutput): Args: last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` @@ -409,8 +397,8 @@ class MoeModelOutputWithPast(ModelOutput): loss for Mixture of Experts models. """ - last_hidden_state: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + last_hidden_state: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None router_logits: Optional[Tuple[ms.Tensor]] = None @@ -438,9 +426,8 @@ class MoeCausalLMOutputWithPast(ModelOutput): Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary loss for Mixture of Experts models. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -459,8 +446,8 @@ class MoeCausalLMOutputWithPast(ModelOutput): loss: Optional[ms.Tensor] = None aux_loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None router_logits: Optional[Tuple[ms.Tensor]] = None @@ -478,11 +465,8 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. 
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` @@ -572,8 +556,8 @@ class Seq2SeqModelOutput(ModelOutput): self-attention heads. """ - last_hidden_state: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + last_hidden_state: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -594,10 +578,8 @@ class Seq2SeqMoEModelOutput(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -642,8 +624,8 @@ class Seq2SeqMoEModelOutput(ModelOutput): modules. """ - last_hidden_state: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + last_hidden_state: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None decoder_router_logits: Optional[Tuple[ms.Tensor]] = None @@ -678,7 +660,7 @@ class CausalLMOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -693,9 +675,8 @@ class CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -713,8 +694,8 @@ class CausalLMOutputWithPast(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -746,18 +727,16 @@ class CausalLMOutputWithCrossAttentions(ModelOutput): Cross attentions weights after the attention softmax, used to compute the weighted average in the cross-attention heads. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `ms.Tensor` tuples of length `config.n_layers`, with each tuple containing the cached key, - value states of the self-attention and the cross-attention layers if model is used in encoder-decoder - setting. Only relevant if `config.is_decoder = True`. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -773,9 +752,8 @@ class SequenceClassifierOutputWithPast(ModelOutput): Classification (or regression if config.num_labels==1) loss. logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
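With `past_key_values` documented as a `Cache` instance instead of a tuple of tuples, the usual consumption pattern stays the same: feed the returned cache object straight back into the next forward call. A hedged sketch of that loop, where `model` and `tok` are placeholders for any mindone.transformers causal LM and its tokenizer, so treat it as a shape-level illustration rather than copy-paste code:

```python
import mindspore as ms

input_ids = ms.Tensor(tok("The quick brown fox", return_tensors="np")["input_ids"])
past_key_values = None

for _ in range(8):
    outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    next_token = outputs.logits[:, -1:].argmax(-1)     # greedy pick, shape (batch, 1)
    past_key_values = outputs.past_key_values           # a Cache instance, reused as-is
    input_ids = next_token                               # only the new token is fed next step
```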
@@ -793,8 +771,8 @@ class SequenceClassifierOutputWithPast(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -823,7 +801,7 @@ class MaskedLMOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -838,10 +816,8 @@ class Seq2SeqLMOutput(ModelOutput): Language modeling loss. logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -878,8 +854,8 @@ class Seq2SeqLMOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -898,10 +874,8 @@ class Seq2SeqMoEOutput(ModelOutput): Language modeling loss. logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
@@ -947,12 +921,12 @@ class Seq2SeqMoEOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - encoder_z_loss: ms.Tensor = None - decoder_z_loss: ms.Tensor = None - encoder_aux_loss: ms.Tensor = None - decoder_aux_loss: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + encoder_z_loss: Optional[ms.Tensor] = None + decoder_z_loss: Optional[ms.Tensor] = None + encoder_aux_loss: Optional[ms.Tensor] = None + decoder_aux_loss: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None decoder_router_logits: Optional[Tuple[ms.Tensor]] = None @@ -988,7 +962,7 @@ class NextSentencePredictorOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1017,7 +991,7 @@ class SequenceClassifierOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1032,10 +1006,8 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) loss. logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
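As the `SequenceClassifierOutput` docstring above says, `logits` are pre-softmax scores of shape `(batch_size, config.num_labels)`. A small post-processing sketch with dummy logits standing in for a real model call:

```python
import mindspore as ms
from mindspore import mint

logits = ms.Tensor([[1.2, -0.3, 0.4]], ms.float32)         # (batch, num_labels)
probs = mint.nn.functional.softmax(logits, dim=-1)
pred = int(probs.argmax(-1)[0])
print(pred, float(probs[0, pred]))                          # predicted label id and its probability
```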
@@ -1072,8 +1044,8 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1108,7 +1080,7 @@ class MultipleChoiceModelOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1137,7 +1109,7 @@ class TokenClassifierOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1168,8 +1140,8 @@ class QuestionAnsweringModelOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - start_logits: ms.Tensor = None - end_logits: ms.Tensor = None + start_logits: Optional[ms.Tensor] = None + end_logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1186,10 +1158,8 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): Span-start scores (before SoftMax). end_logits (`ms.Tensor` of shape `(batch_size, sequence_length)`): Span-end scores (before SoftMax). - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
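For the question-answering outputs above, `start_logits` and `end_logits` are per-token span scores before SoftMax. A toy span-extraction step with dummy logits; a real pipeline would additionally map token indices back to characters and keep an n-best list:

```python
import mindspore as ms

start_logits = ms.Tensor([[0.1, 2.5, 0.3, -1.0, 0.0]], ms.float32)   # (batch, seq_len)
end_logits = ms.Tensor([[0.0, 0.2, 3.1, -0.5, 0.1]], ms.float32)

start_idx = int(start_logits.argmax(-1)[0])
end_idx = int(end_logits[:, start_idx:].argmax(-1)[0]) + start_idx    # constrain end >= start
print(start_idx, end_idx)                                             # e.g. (1, 2): the answer token span
```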
@@ -1226,9 +1196,9 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - start_logits: ms.Tensor = None - end_logits: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + start_logits: Optional[ms.Tensor] = None + end_logits: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1270,7 +1240,7 @@ class SemanticSegmenterOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None + logits: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1349,7 +1319,7 @@ class DepthEstimatorOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - predicted_depth: ms.Tensor = None + predicted_depth: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1377,7 +1347,7 @@ class ImageSuperResolutionOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - reconstruction: ms.Tensor = None + reconstruction: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1405,8 +1375,8 @@ class Wav2Vec2BaseModelOutput(ModelOutput): heads. """ - last_hidden_state: ms.Tensor = None - extract_features: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None + extract_features: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1437,8 +1407,8 @@ class XVectorOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - logits: ms.Tensor = None - embeddings: ms.Tensor = None + logits: Optional[ms.Tensor] = None + embeddings: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1465,7 +1435,7 @@ class BackboneOutput(ModelOutput): heads. """ - feature_maps: Tuple[ms.Tensor] = None + feature_maps: Optional[Tuple[ms.Tensor]] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1500,8 +1470,8 @@ class BaseModelOutputWithPoolingAndProjection(ModelOutput): Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. """ - last_hidden_state: ms.Tensor = None - pooler_output: ms.Tensor = None + last_hidden_state: Optional[ms.Tensor] = None + pooler_output: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None projection_state: Optional[Tuple[ms.Tensor]] = None @@ -1517,10 +1487,8 @@ class Seq2SeqSpectrogramOutput(ModelOutput): Spectrogram generation loss. spectrogram (`ms.Tensor` of shape `(batch_size, sequence_length, num_bins)`): The predicted spectrogram. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -1557,8 +1525,8 @@ class Seq2SeqSpectrogramOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - spectrogram: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + spectrogram: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1579,10 +1547,8 @@ class Seq2SeqTSModelOutput(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -1626,8 +1592,8 @@ class Seq2SeqTSModelOutput(ModelOutput): Static features of each time series' in a batch which are copied to the covariates at inference time. """ - last_hidden_state: ms.Tensor = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + last_hidden_state: Optional[ms.Tensor] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1650,10 +1616,8 @@ class Seq2SeqTSPredictionOutput(ModelOutput): Distributional loss. params (`ms.Tensor` of shape `(batch_size, num_samples, num_params)`): Parameters of the chosen distribution. - past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
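`XVectorOutput` above pairs classification `logits` with an `embeddings` tensor; such utterance-level embeddings are normally compared with cosine similarity for speaker verification or retrieval. A toy comparison with made-up vectors, written with basic tensor ops only:

```python
import mindspore as ms

# Dummy stand-ins for XVectorOutput.embeddings from two utterances
emb_a = ms.Tensor([[0.2, 0.9, -0.1]], ms.float32)
emb_b = ms.Tensor([[0.25, 0.8, 0.0]], ms.float32)

num = (emb_a * emb_b).sum(-1)
den = (emb_a * emb_a).sum(-1).sqrt() * (emb_b * emb_b).sum(-1).sqrt()
print(float((num / den)[0]))        # close to 1.0 suggests the same speaker
```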
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. @@ -1699,7 +1663,7 @@ class Seq2SeqTSPredictionOutput(ModelOutput): loss: Optional[ms.Tensor] = None params: Optional[Tuple[ms.Tensor]] = None - past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None + past_key_values: Optional[EncoderDecoderCache] = None decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None @@ -1722,7 +1686,7 @@ class SampleTSPredictionOutput(ModelOutput): Sampled values from the chosen distribution. """ - sequences: ms.Tensor = None + sequences: Optional[ms.Tensor] = None @dataclass @@ -1748,7 +1712,7 @@ class MaskedImageModelingOutput(ModelOutput): """ loss: Optional[ms.Tensor] = None - reconstruction: ms.Tensor = None + reconstruction: Optional[ms.Tensor] = None hidden_states: Optional[Tuple[ms.Tensor, ...]] = None attentions: Optional[Tuple[ms.Tensor, ...]] = None From f75b06a8ebb77c1404f063fcc98cfdfb257c3b85 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 13:50:32 +0800 Subject: [PATCH 38/94] fix pre-commit errors --- mindone/transformers/image_processing_base.py | 1 - .../image_processing_utils_fast.py | 21 ++++--- mindone/transformers/image_transforms.py | 8 +-- mindone/transformers/image_utils.py | 10 ++- .../integrations/sdpa_attention.py | 1 + mindone/transformers/masking_utils.py | 13 +++- .../modeling_flash_attention_utils.py | 44 ++++--------- mindone/transformers/modeling_outputs.py | 3 +- mindone/transformers/processing_utils.py | 18 +++--- mindone/transformers/tokenization_utils.py | 11 ++-- .../transformers/tokenization_utils_base.py | 63 +++++++------------ mindone/transformers/utils/backbone_utils.py | 2 +- mindone/transformers/utils/generic.py | 21 +++---- mindone/transformers/utils/import_utils.py | 4 ++ mindone/transformers/video_utils.py | 15 ++--- 15 files changed, 105 insertions(+), 130 deletions(-) diff --git a/mindone/transformers/image_processing_base.py b/mindone/transformers/image_processing_base.py index d4dfd78236..c66c3b618b 100644 --- a/mindone/transformers/image_processing_base.py +++ b/mindone/transformers/image_processing_base.py @@ -383,7 +383,6 @@ def get_image_processor_dict( f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" ) - return image_processor_dict, kwargs @classmethod diff --git a/mindone/transformers/image_processing_utils_fast.py b/mindone/transformers/image_processing_utils_fast.py index 8af174e6e1..dcca36084e 100644 --- a/mindone/transformers/image_processing_utils_fast.py +++ b/mindone/transformers/image_processing_utils_fast.py @@ -16,14 +16,14 @@ # limitations under the License. 
from collections.abc import Iterable +from copy import deepcopy from functools import lru_cache, partial from typing import Any, Optional, TypedDict, Union -from copy import deepcopy + import numpy as np from PIL import Image +from transformers.utils import auto_docstring, logging -from mindspore import mint -import mindspore as ms from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from .image_transforms import ( @@ -49,14 +49,14 @@ ) from .processing_utils import Unpack from .utils import TensorType, is_mindspore_available, is_mindspore_tensor, is_vision_available -from transformers.utils import auto_docstring, logging - if is_vision_available(): from .image_utils import PILImageResampling if is_mindspore_available(): import mindspore as ms + from mindspore import mint + import mindspore.mint.functional as F from mindspore.dataset import vision from mindspore.dataset.vision import Inter as InterpolationMode @@ -176,7 +176,6 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False): input_data_format: Optional[Union[str, ChannelDimension]] - @auto_docstring class BaseImageProcessorFast(BaseImageProcessor): resample = None @@ -222,7 +221,6 @@ def __init__( # get valid kwargs names self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys()) - def resize( self, image: "ms.Tensor", @@ -289,7 +287,7 @@ def compile_friendly_resize( A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor. """ raise NotImplementedError("This method is not implemented for mindspore") - + def rescale( self, image: "ms.Tensor", @@ -526,7 +524,7 @@ def _prepare_image_like_inputs( processed_images = [process_image_partial(img) for img in images] return processed_images - + def _further_process_kwargs( self, size: Optional[SizeDict] = None, @@ -626,7 +624,9 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. 
kwargs["interpolation"] = ( - pil_mindspore_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample + pil_mindspore_interpolation_mapping[resample] + if isinstance(resample, (int, PILImageResampling)) + else resample ) # Pop kwargs that are not needed in _preprocess @@ -706,6 +706,7 @@ def to_dict(self): encoder_dict.pop("_valid_kwargs_names", None) return encoder_dict + class SemanticSegmentationMixin: def post_process_semantic_segmentation(self, outputs, target_sizes: list[tuple] = None): """ diff --git a/mindone/transformers/image_transforms.py b/mindone/transformers/image_transforms.py index 8fa9b969e4..1ea2602c44 100644 --- a/mindone/transformers/image_transforms.py +++ b/mindone/transformers/image_transforms.py @@ -469,7 +469,6 @@ def center_crop( """ requires_backends(center_crop, ["vision"]) - if not isinstance(image, np.ndarray): raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") @@ -521,7 +520,6 @@ def center_crop( new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST) - return new_image @@ -802,6 +800,7 @@ def _cast_tensor_to_float(x): return x return x.float() + def _group_images_by_shape(nested_images, is_nested: bool = False): """Helper function to flatten a single level of nested image structures and group by shape.""" grouped_images = defaultdict(list) @@ -842,13 +841,12 @@ def _reconstruct_nested_structure(indices, processed_images): return result + def group_images_by_shape( images: Union[list["ms.Tensor"], "ms.Tensor"], disable_grouping: bool, is_nested: bool = False, -) -> tuple[ - dict[tuple[int, int], list["ms.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]] -]: +) -> tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]: """ Groups images by shape. 
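`group_images_by_shape` / `reorder_images` (signatures above) implement a batch-then-restore pattern: images with identical spatial size are stacked and transformed together, then put back in their original order. A standalone sketch of the idea, not the library code, with illustrative shapes:

```python
from collections import defaultdict
import numpy as np

images = [np.zeros((3, 224, 224)), np.zeros((3, 336, 336)), np.zeros((3, 224, 224))]

grouped, index_of = defaultdict(list), {}
for i, img in enumerate(images):
    shape = img.shape[-2:]
    index_of[i] = (shape, len(grouped[shape]))            # remember where each image lands
    grouped[shape].append(img)

stacked = {shape: np.stack(imgs) for shape, imgs in grouped.items()}   # one batch per shape
# ... apply the batched transform to each stacked[shape] here ...
restored = [stacked[shape][pos] for shape, pos in (index_of[i] for i in range(len(images)))]
```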
Returns a dictionary with the shape as key and a list of images with that shape as value, diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py index a3c5d07da1..666d1686e0 100644 --- a/mindone/transformers/image_utils.py +++ b/mindone/transformers/image_utils.py @@ -21,12 +21,14 @@ from dataclasses import dataclass from io import BytesIO from typing import TYPE_CHECKING, Optional, Union -from mindspore import mint + import numpy as np import requests from packaging import version from transformers.utils import logging +from mindspore import mint + from .utils import ( ExplicitEnum, is_mindspore_available, @@ -44,6 +46,7 @@ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ) + if is_vision_available(): import PIL.Image import PIL.ImageOps @@ -97,7 +100,6 @@ class AnnotionFormat(ExplicitEnum): COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value - AnnotationType = dict[str, Union[int, str, list[dict]]] @@ -130,6 +132,7 @@ def is_valid_image(img): def is_valid_list_of_images(images: list): return images and all(is_valid_image(image) for image in images) + def concatenate_list(input_list): if isinstance(input_list[0], list): return [item for sublist in input_list for item in sublist] @@ -138,6 +141,7 @@ def concatenate_list(input_list): elif isinstance(input_list[0], mint.Tensor): return mint.cat(input_list, dim=0) + def valid_images(imgs): # If we have an list of images, make sure every image is valid if isinstance(imgs, (list, tuple)): @@ -291,6 +295,7 @@ def to_numpy_array(img) -> np.ndarray: return np.array(img) return to_numpy(img) + def pil_to_tensor(image, is_normalize=True): """ Pillow image to mindspore tensor @@ -349,6 +354,7 @@ def infer_channel_dimension_format( return ChannelDimension.LAST raise ValueError("Unable to infer channel dimension format") + def get_channel_dimension_axis( image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None ) -> int: diff --git a/mindone/transformers/integrations/sdpa_attention.py b/mindone/transformers/integrations/sdpa_attention.py index 0c1597eac6..604c888884 100644 --- a/mindone/transformers/integrations/sdpa_attention.py +++ b/mindone/transformers/integrations/sdpa_attention.py @@ -21,6 +21,7 @@ def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim)) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + def use_gqa_in_sdpa(attention_mask: Optional[ms.Tensor], key: ms.Tensor) -> bool: # GQA is not supported yet. return False diff --git a/mindone/transformers/masking_utils.py b/mindone/transformers/masking_utils.py index 7f8dfe3d93..d4e60a204f 100644 --- a/mindone/transformers/masking_utils.py +++ b/mindone/transformers/masking_utils.py @@ -22,7 +22,7 @@ import mindspore as ms from mindspore import mint - +import mindspore.mint.functional as F from .cache_utils import Cache from .modeling_attn_mask_utils import dtype_to_min from .utils.generic import GeneralInterface @@ -76,6 +76,7 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return inner_mask + def chunked_overlay(chunk_size: int) -> Callable: """ This is an overlay depicting a chuned attention pattern. 
Add it on top of a causal mask for a proper chunked @@ -87,12 +88,14 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return inner_mask + def sliding_window_causal_mask_function(sliding_window: int) -> Callable: """ This return the mask_function function to create a sliding window mask. """ return and_masks(sliding_window_overlay(sliding_window), causal_mask_function) + def chunked_causal_mask_function(chunk_size: int) -> Callable: """ This return the mask_function function to create a chunked attention mask. @@ -136,6 +139,7 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return inner_mask + def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """ Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over @@ -184,6 +188,7 @@ def prepare_padding_mask( local_padding_mask = local_padding_mask[:, mask_indices] return local_padding_mask + def sdpa_mask_recent_torch( batch_size: int, cache_position: ms.Tensor, @@ -548,6 +553,7 @@ class AttentionMaskInterface(GeneralInterface): # Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface() + def find_packed_sequence_indices(position_ids: ms.Tensor) -> ms.Tensor: """ Find the indices of the sequence to which each new query token in the sequence belongs when using packed @@ -567,7 +573,7 @@ def find_packed_sequence_indices(position_ids: ms.Tensor) -> ms.Tensor: # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence # cannot be part of the end of the first batch dim and the start of the 2nd one for example first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1 - position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1) + position_diff = mint.diff(position_ids, prepend=first_dummy_value, dim=-1) packed_sequence_mask = (position_diff != 1).cumsum(-1) # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0` @@ -812,7 +818,6 @@ def create_sliding_window_causal_mask( mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False - # We now create the mask causal_mask = mask_interface( batch_size=batch_size, @@ -828,6 +833,7 @@ def create_sliding_window_causal_mask( ) return causal_mask + def create_chunked_causal_mask( config: PretrainedConfig, input_embeds: ms.Tensor, @@ -932,6 +938,7 @@ def create_chunked_causal_mask( "chunked_attention": create_chunked_causal_mask, } + def create_masks_for_generate( config: PretrainedConfig, input_embeds: ms.Tensor, diff --git a/mindone/transformers/modeling_flash_attention_utils.py b/mindone/transformers/modeling_flash_attention_utils.py index b509236cd3..a133452e84 100644 --- a/mindone/transformers/modeling_flash_attention_utils.py +++ b/mindone/transformers/modeling_flash_attention_utils.py @@ -15,15 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
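A worked example of the packed-sequence detection in `find_packed_sequence_indices` above, using NumPy as a stand-in for the mint ops so the numbers are easy to check: a new sequence starts wherever `position_ids` does not increase by exactly 1.

```python
import numpy as np

position_ids = np.array([[0, 1, 2, 0, 1, 0, 1, 2, 3]])         # three packed sequences
first_dummy_value = position_ids[:, :1] - 1                     # makes the very first diff equal to 1
position_diff = np.diff(position_ids, prepend=first_dummy_value, axis=-1)
packed_sequence_mask = (position_diff != 1).cumsum(-1)
print(packed_sequence_mask)                                     # [[0 0 0 1 1 2 2 2 2]]
```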
import inspect -import warnings import os +import warnings from typing import Optional, TypedDict from transformers.utils import logging import mindspore as ms -from mindspore import mint import mindspore.mint.functional as F +from mindspore import mint + logger = logging.get_logger(__name__) @@ -173,7 +174,8 @@ def _upad_input( elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = mint.arange( - batch_size + 1, dtype=ms.int32, + batch_size + 1, + dtype=ms.int32, ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) @@ -260,7 +262,8 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[ms.Type] = None): q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) return q, k, v -#TODO: fix this flash attention 2 and 3 + +# TODO: fix this flash attention 2 and 3 def _lazy_imports(impl: Optional[str]): # returns funcs and pad/unpad based on impl is_fa2 = is_flash_attn_2_available() or is_torch_npu_available() @@ -283,20 +286,7 @@ def _lazy_imports(impl: Optional[str]): ) globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} if globals()["use_remote_fa2"]: - if not is_kernels_available(): - raise ImportError("You need to install kernels: `pip install kernels`") - from kernels import get_kernel - - impl = get_kernel("kernels-community/flash-attn") - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input - return ( - getattr(impl, "flash_attn_func", None), - getattr(impl, "flash_attn_varlen_func"), - pad_input, - unpad_input, - True, - ) - + raise NotImplementedError("Remote flash attention 2 is not supported yet.") else: raise ImportError( "Failed to import flash attention 2, please install it or use another implementation." @@ -326,15 +316,7 @@ def is_flash_attn_available(): def flash_attn_supports_top_left_mask(): - if is_flash_attn_3_available(): - return False - if is_flash_attn_2_available(): - return not is_flash_attn_greater_or_equal_2_10() - - from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask - - return is_npu_fa2_top_left_aligned_causal_mask() - + raise NotImplementedError("flash_attn_supports_top_left_mask is not supported yet.") class FlashAttentionKwargs(TypedDict, total=False): cumulative_seqlens_q: Optional[ms.Tensor] @@ -387,9 +369,9 @@ def _flash_attention_forward( flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} if not is_fa3: flash_kwargs["dropout_p"] = dropout - if is_flash_attn_greater_or_equal("2.4.1"): - det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - flash_kwargs["deterministic"] = det + # if is_flash_attn_greater_or_equal("2.4.1"): + # det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + # flash_kwargs["deterministic"] = det if softcap is not None: flash_kwargs["softcap"] = softcap @@ -456,5 +438,3 @@ def _flash_attention_forward( ) return out[0] if isinstance(out, tuple) else out - - diff --git a/mindone/transformers/modeling_outputs.py b/mindone/transformers/modeling_outputs.py index 8b4be73ec8..134422d929 100644 --- a/mindone/transformers/modeling_outputs.py +++ b/mindone/transformers/modeling_outputs.py @@ -20,9 +20,10 @@ from transformers.utils import ModelOutput -from .cache_utils import Cache, EncoderDecoderCache import mindspore as ms +from .cache_utils import Cache, EncoderDecoderCache + @dataclass class BaseModelOutput(ModelOutput): diff --git 
a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py index 170a83c902..362e54d2cc 100644 --- a/mindone/transformers/processing_utils.py +++ b/mindone/transformers/processing_utils.py @@ -34,12 +34,13 @@ import typing_extensions from huggingface_hub.errors import EntryNotFoundError from transformers.dynamic_module_utils import custom_object_save +from transformers.utils.chat_template_utils import render_jinja_template from .audio_utils import load_audio from .feature_extraction_utils import BatchFeature from .image_utils import ChannelDimension, is_vision_available, load_image -from transformers.utils.chat_template_utils import render_jinja_template from .video_utils import VideoMetadata, load_video + if is_vision_available(): from .image_utils import PILImageResampling @@ -68,9 +69,9 @@ list_repo_templates, logging, ) +from transformers.utils.deprecation import deprecate_kwarg from .utils import TensorType -from transformers.utils.deprecation import deprecate_kwarg logger = logging.get_logger(__name__) @@ -417,6 +418,7 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): continue_final_message: Optional[bool] = False return_assistant_tokens_mask: Optional[bool] = False + class ChatTemplateLoadKwargs(TypedDict, total=False): """ Keyword arguments used to load multimodal data in processor chat templates. @@ -443,6 +445,7 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs): sampling_rate: Optional[int] = 16_000 load_audio_from_video: Optional[bool] = False + class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False): """ Keyword arguments for processor's `apply_chat_template`. @@ -456,6 +459,7 @@ class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateK tokenize: Optional[bool] = False return_dict: Optional[bool] = False + class AllKwargsForChatTemplate( TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs ): @@ -521,7 +525,6 @@ def __init__(self, *args, **kwargs): # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None: proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value) - # Sanitize args and kwargs for key in kwargs: @@ -1441,9 +1444,7 @@ def apply_chat_template( elif self.chat_template is not None: chat_template = self.chat_template else: - raise ValueError( - "Cannot use apply_chat_template because this processor does not have a chat template." - ) + raise ValueError("Cannot use apply_chat_template because this processor does not have a chat template.") else: if isinstance(self.chat_template, dict) and chat_template in self.chat_template: # It's the name of a template, not a full template string @@ -1615,8 +1616,7 @@ def apply_chat_template( end_pos = bisect.bisect_left(offset_starts, assistant_end_char) if not ( - start_pos >= 0 - and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1] + start_pos >= 0 and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1] ): # start_token is out of bounds maybe due to truncation. 
continue @@ -1648,7 +1648,6 @@ def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens """ return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) - def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]): """ Checks that number of special tokens in text and processed text is same. The count can be different @@ -1666,6 +1665,7 @@ def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", "Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`." ) + ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) if ProcessorMixin.push_to_hub.__doc__ is not None: ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( diff --git a/mindone/transformers/tokenization_utils.py b/mindone/transformers/tokenization_utils.py index 2f4a2f8a37..40347322fd 100644 --- a/mindone/transformers/tokenization_utils.py +++ b/mindone/transformers/tokenization_utils.py @@ -43,7 +43,6 @@ ) from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging - logger = logging.get_logger(__name__) # Slow tokenizers are saved in a vocabulary plus three separated files @@ -559,9 +558,7 @@ def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_to else: # very important for fast and slow equivalence! is_special = token in self.all_special_tokens or special_tokens - token = AddedToken( - token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special - ) + token = AddedToken(token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special) elif special_tokens: # doing token.special=True changes the normalization! will fix in rust # this is important and the only reason why the AddedTokens in each class are normalized by default @@ -1037,10 +1034,12 @@ def get_special_tokens_mask( return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) @overload - def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ... + def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: + ... @overload - def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ... + def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: + ... def convert_ids_to_tokens( self, ids: Union[int, list[int]], skip_special_tokens: bool = False diff --git a/mindone/transformers/tokenization_utils_base.py b/mindone/transformers/tokenization_utils_base.py index 78788edf6b..46697b8adf 100644 --- a/mindone/transformers/tokenization_utils_base.py +++ b/mindone/transformers/tokenization_utils_base.py @@ -35,18 +35,7 @@ import numpy as np from packaging import version - -from . import __version__ from transformers.dynamic_module_utils import custom_object_save -from .utils import ( - ExplicitEnum, - PaddingStrategy, - TensorType, - is_numpy_array, - requires_backends, - to_py_obj, -) - from transformers.utils import ( CHAT_TEMPLATE_DIR, CHAT_TEMPLATE_FILE, @@ -70,11 +59,12 @@ is_torch_tensor, list_repo_templates, logging, - ) from transformers.utils.chat_template_utils import render_jinja_template from transformers.utils.import_utils import PROTOBUF_IMPORT_ERROR +from . 
import __version__ +from .utils import ExplicitEnum, PaddingStrategy, TensorType, is_numpy_array, requires_backends, to_py_obj if TYPE_CHECKING: if is_torch_available(): @@ -109,9 +99,7 @@ class AddedToken: `tokenizers`. """ - def __init__( - self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None - ): + def __init__(self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None): self.content = content self.single_word = single_word self.lstrip = lstrip @@ -567,9 +555,7 @@ def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = return CharSpan(*span_indices) if span_indices is not None else None - def char_to_token( - self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0 - ) -> int: + def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: """ Get the index of the token in the encoded output comprising a character in the original string for a sequence of the batch. @@ -716,9 +702,7 @@ def convert_to_tensors( # Get a function reference for the correct framework if tensor_type == TensorType.TENSORFLOW: if not is_tf_available(): - raise ImportError( - "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." - ) + raise ImportError("Unable to convert output to TensorFlow tensors format, TensorFlow is not installed.") import tensorflow as tf as_tensor = tf.constant @@ -752,6 +736,7 @@ def as_tensor(value, dtype=None): def is_tensor(obj): return isinstance(obj, mx.array) + else: def as_tensor(value, dtype=None): @@ -877,9 +862,9 @@ def __init__(self, verbose=False, **kwargs): if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == "additional_special_tokens": assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" - assert all(isinstance(t, (str, AddedToken)) for t in value), ( - "One of the tokens is not a string or an AddedToken" - ) + assert all( + isinstance(t, (str, AddedToken)) for t in value + ), "One of the tokens is not a string or an AddedToken" setattr(self, key, value) elif isinstance(value, (str, AddedToken)): setattr(self, key, value) @@ -965,9 +950,9 @@ def add_special_tokens( logger.info(f"Assigning {value} to the {key} key of the tokenizer") if key == "additional_special_tokens": - assert isinstance(value, (list, tuple)) and all(isinstance(t, (str, AddedToken)) for t in value), ( - f"Tokens {value} for key {key} should all be str or AddedToken instances" - ) + assert isinstance(value, (list, tuple)) and all( + isinstance(t, (str, AddedToken)) for t in value + ), f"Tokens {value} for key {key} should all be str or AddedToken instances" to_add = [] for token in value: @@ -1434,7 +1419,9 @@ def __init__(self, **kwargs): # By default, do not split special tokens for both fast and slow tokenizers self.split_special_tokens = kwargs.pop("split_special_tokens", False) - self.deprecation_warnings = {} # Use to store when we have already noticed a deprecation warning (avoid overlogging). + self.deprecation_warnings = ( + {} + ) # Use to store when we have already noticed a deprecation warning (avoid overlogging). self._in_target_context_manager = False # Stores a Jinja template that formats chat histories into tokenizable strings @@ -1473,9 +1460,7 @@ def max_len_single_sentence(self, value) -> int: ) self.deprecation_warnings["max_len_single_sentence"] = True else: - raise ValueError( - "Setting 'max_len_single_sentence' is now deprecated. 
This value is automatically set up." - ) + raise ValueError("Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up.") @max_len_sentences_pair.setter def max_len_sentences_pair(self, value) -> int: @@ -1956,9 +1941,9 @@ def from_pretrained( if template_dir.is_dir(): for template_file in template_dir.glob("*.jinja"): template_name = template_file.name.removesuffix(".jinja") - vocab_files[f"chat_template_{template_name}"] = ( - f"{CHAT_TEMPLATE_DIR}/{template_file.name}" - ) + vocab_files[ + f"chat_template_{template_name}" + ] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}" else: for template in list_repo_templates( pretrained_model_name_or_path, @@ -3352,9 +3337,9 @@ def pad( return BatchEncoding(encoded_inputs, tensor_type=return_tensors) batch_size = len(required_input) - assert all(len(v) == batch_size for v in encoded_inputs.values()), ( - "Some items in the output dictionary have a different batch size than others." - ) + assert all( + len(v) == batch_size for v in encoded_inputs.values() + ), "Some items in the output dictionary have a different batch size than others." if padding_strategy == PaddingStrategy.LONGEST: max_length = max(len(inputs) for inputs in required_input) @@ -3555,9 +3540,7 @@ def prepare_for_model( if return_length: encoded_inputs["length"] = len(encoded_inputs["input_ids"]) - batch_outputs = BatchEncoding( - encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis - ) + batch_outputs = BatchEncoding(encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis) return batch_outputs diff --git a/mindone/transformers/utils/backbone_utils.py b/mindone/transformers/utils/backbone_utils.py index 6ac303da34..693067ea0e 100644 --- a/mindone/transformers/utils/backbone_utils.py +++ b/mindone/transformers/utils/backbone_utils.py @@ -22,10 +22,10 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Optional, Union - if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig + class BackboneType(enum.Enum): TIMM = "timm" TRANSFORMERS = "transformers" diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index 0597f24e0f..ce1dd43a53 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -37,7 +37,7 @@ from .import_utils import is_mindspore_available if is_mindspore_available(): - import mindspore # noqa: F401 + import mindspore as ms # noqa: F401 _CAN_RECORD_REGISTRY = {} @@ -115,7 +115,7 @@ def _get_frameworks_and_test_func(x): def is_tensor(x): """ - Tests if `x` is a `mindspore.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array` + Tests if `x` is a `ms.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array` in the order defined by `infer_framework_from_repr` """ # This gives us a smart order to test the frameworks with the corresponding tests. 
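A conceptual sketch of the dispatch that the comment above describes for `is_tensor`: guess the framework from the object's repr, then run the per-framework checks with the guessed one first. The table below is illustrative, not the module's actual registry:

```python
import numpy as np

framework_to_test = {
    "np": lambda x: isinstance(x, np.ndarray),
    "ms": lambda x: type(x).__module__.startswith("mindspore"),
}

def looks_like_tensor(x):
    guess = "np" if "numpy" in str(type(x)) else "ms"           # crude repr-based heuristic
    ordered = [guess] + [name for name in framework_to_test if name != guess]
    return any(framework_to_test[name](x) for name in ordered)

print(looks_like_tensor(np.ones(2)), looks_like_tensor("not a tensor"))   # True False
```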
@@ -141,7 +141,7 @@ def is_numpy_array(x): def _is_mindspore(x): import mindspore - return isinstance(x, mindspore.Tensor) + return isinstance(x, ms.Tensor) def is_mindspore_tensor(x): @@ -532,7 +532,6 @@ def tensor_size(array): raise ValueError(f"Type not supported for tensor_size: {type(array)}.") - def infer_framework(model_class): """ Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant @@ -692,7 +691,7 @@ class TransformersKwargs(TypedDict, total=False): Keyword arguments to be passed to the loss function Attributes: - num_items_in_batch (`Optional[mindspore.Tensor]`, *optional*): + num_items_in_batch (`Optional[ms.Tensor]`, *optional*): Number of items in the batch. It is recommended to pass it when you are doing gradient accumulation. output_hidden_states (`Optional[bool]`, *optional*): @@ -701,9 +700,9 @@ class TransformersKwargs(TypedDict, total=False): Turn this on to return the intermediary attention scores. output_router_logits (`Optional[bool]`, *optional*): For MoE models, this allows returning the router logits to compute the loss. - cumulative_seqlens_q (`mindspore.Tensor`, *optional*) + cumulative_seqlens_q (`ms.Tensor`, *optional*) Gets cumulative sequence length for query state. - cumulative_seqlens_k (`mindspore.Tensor`, *optional*) + cumulative_seqlens_k (`ms.Tensor`, *optional*) Gets cumulative sequence length for key state. max_length_q (`int`, *optional*): Maximum sequence length for query state. @@ -711,12 +710,12 @@ class TransformersKwargs(TypedDict, total=False): Maximum sequence length for key state. """ - num_items_in_batch: Optional["mindspore.Tensor"] + num_items_in_batch: Optional["ms.Tensor"] output_hidden_states: Optional[bool] output_attentions: Optional[bool] output_router_logits: Optional[bool] - cumulative_seqlens_q: Optional["mindspore.Tensor"] - cumulative_seqlens_k: Optional["mindspore.Tensor"] + cumulative_seqlens_q: Optional["ms.Tensor"] + cumulative_seqlens_k: Optional["ms.Tensor"] max_length_q: Optional[int] max_length_k: Optional[int] @@ -976,4 +975,4 @@ class LossKwargs(TypedDict, total=False): you are doing gradient accumulation. 
""" - num_items_in_batch: Optional[mindspore.Tensor] + num_items_in_batch: Optional[ms.Tensor] diff --git a/mindone/transformers/utils/import_utils.py b/mindone/transformers/utils/import_utils.py index 83265524a0..0a0edcec6c 100644 --- a/mindone/transformers/utils/import_utils.py +++ b/mindone/transformers/utils/import_utils.py @@ -64,12 +64,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ else: return package_exists + _av_available = importlib.util.find_spec("av") is not None _decord_available = importlib.util.find_spec("decord") is not None _scipy_available = _is_package_available("scipy") _cv2_available = importlib.util.find_spec("cv2") is not None _yt_dlp_available = importlib.util.find_spec("yt_dlp") is not None + def is_mindspore_available(): _mindspore_available, _mindspore_version = _is_package_available("mindspore", return_version=True) return _mindspore_available @@ -91,6 +93,7 @@ def is_av_available(): def is_decord_available(): return _decord_available + def is_cv2_available(): return _cv2_available @@ -98,6 +101,7 @@ def is_cv2_available(): def is_yt_dlp_available(): return _yt_dlp_available + @lru_cache def is_vision_available(): _pil_available = importlib.util.find_spec("PIL") is not None diff --git a/mindone/transformers/video_utils.py b/mindone/transformers/video_utils.py index 4106f04553..407bb10e04 100644 --- a/mindone/transformers/video_utils.py +++ b/mindone/transformers/video_utils.py @@ -14,7 +14,6 @@ # limitations under the License. import os -import warnings from collections.abc import Iterable from contextlib import redirect_stdout from dataclasses import dataclass @@ -24,6 +23,7 @@ import numpy as np import requests +from transformers.utils import logging from .image_transforms import PaddingMode, to_channel_dimension_format from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image @@ -31,22 +31,21 @@ is_av_available, is_cv2_available, is_decord_available, + is_mindspore_available, + is_mindspore_tensor, is_numpy_array, is_vision_available, is_yt_dlp_available, requires_backends, - is_mindspore_available, - is_mindspore_tensor, ) -from transformers.utils import logging - if is_vision_available(): import PIL.Image import PIL.ImageOps if is_mindspore_available(): - import mindspore + import mindspore as ms + from mindspore import mint logger = logging.get_logger(__name__) @@ -656,9 +655,7 @@ def _expand_for_data_format(values): raise ValueError(f"Unsupported format: {values}") # add 0 for channel dimension - values = ( - ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0)) - ) + values = ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0)) # Add additional padding if there's a batch dimension values = (0, *values) if video.ndim == 5 else values From f9ea8ced0fb31f54223fd0105b849de2b4b5c8be Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:10:23 +0800 Subject: [PATCH 39/94] fix pre-commit errors --- .../image_processing_utils_fast.py | 3 +-- mindone/transformers/image_utils.py | 4 +++- mindone/transformers/masking_utils.py | 3 ++- .../modeling_flash_attention_utils.py | 7 +++++-- mindone/transformers/processing_utils.py | 4 +++- mindone/transformers/tokenization_utils.py | 3 ++- .../transformers/tokenization_utils_base.py | 17 ++++++++++------- mindone/transformers/utils/generic.py | 10 +++++----- mindone/transformers/video_utils.py | 18 
+++++++++--------- 9 files changed, 40 insertions(+), 29 deletions(-) diff --git a/mindone/transformers/image_processing_utils_fast.py b/mindone/transformers/image_processing_utils_fast.py index dcca36084e..65b2b8eaba 100644 --- a/mindone/transformers/image_processing_utils_fast.py +++ b/mindone/transformers/image_processing_utils_fast.py @@ -24,7 +24,6 @@ from PIL import Image from transformers.utils import auto_docstring, logging - from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from .image_transforms import ( convert_to_rgb, @@ -55,8 +54,8 @@ if is_mindspore_available(): import mindspore as ms - from mindspore import mint import mindspore.mint.functional as F + from mindspore import mint from mindspore.dataset import vision from mindspore.dataset.vision import Inter as InterpolationMode diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py index 666d1686e0..a0bbb5bd3c 100644 --- a/mindone/transformers/image_utils.py +++ b/mindone/transformers/image_utils.py @@ -345,7 +345,9 @@ def infer_channel_dimension_format( if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels: logger.warning( - f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension." + f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the\ + [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) \ + parameter to assign the channel dimension." ) return ChannelDimension.FIRST elif image.shape[first_dim] in num_channels: diff --git a/mindone/transformers/masking_utils.py b/mindone/transformers/masking_utils.py index d4e60a204f..fa30c6256c 100644 --- a/mindone/transformers/masking_utils.py +++ b/mindone/transformers/masking_utils.py @@ -21,8 +21,9 @@ from transformers.configuration_utils import PretrainedConfig import mindspore as ms -from mindspore import mint import mindspore.mint.functional as F +from mindspore import mint + from .cache_utils import Cache from .modeling_attn_mask_utils import dtype_to_min from .utils.generic import GeneralInterface diff --git a/mindone/transformers/modeling_flash_attention_utils.py b/mindone/transformers/modeling_flash_attention_utils.py index a133452e84..49d58197de 100644 --- a/mindone/transformers/modeling_flash_attention_utils.py +++ b/mindone/transformers/modeling_flash_attention_utils.py @@ -151,9 +151,11 @@ def _upad_input( indices_q (`ms.Tensor`): The indices of non-masked tokens from the flattened input target sequence. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. \ + `cu_seqlens` shape is (batch_size + 1,). (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e.\ + query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) @@ -318,6 +320,7 @@ def is_flash_attn_available(): def flash_attn_supports_top_left_mask(): raise NotImplementedError("flash_attn_supports_top_left_mask is not supported yet.") + class FlashAttentionKwargs(TypedDict, total=False): cumulative_seqlens_q: Optional[ms.Tensor] cumulative_seqlens_k: Optional[ms.Tensor] diff --git a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py index 362e54d2cc..130f8b5100 100644 --- a/mindone/transformers/processing_utils.py +++ b/mindone/transformers/processing_utils.py @@ -1458,7 +1458,9 @@ def apply_chat_template( if kwargs.get("continue_final_message", False): if kwargs.get("add_generation_prompt", False): raise ValueError( - "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message \ + when you want the model to continue the final message, and add_generation_prompt when\ + you want to add a header that will prompt it to start a new assistant message instead." ) if kwargs.get("return_assistant_tokens_mask", False): raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") diff --git a/mindone/transformers/tokenization_utils.py b/mindone/transformers/tokenization_utils.py index 40347322fd..a3af3bc637 100644 --- a/mindone/transformers/tokenization_utils.py +++ b/mindone/transformers/tokenization_utils.py @@ -479,7 +479,8 @@ def added_tokens_decoder(self, value: dict[int, Union[AddedToken, str]]) -> dict for index, token in value.items(): if not isinstance(token, (str, AddedToken)) or not isinstance(index, int): raise TypeError( - f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}" + f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, \ + should be a dict of {int, Union[AddedToken, str]}" ) self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token diff --git a/mindone/transformers/tokenization_utils_base.py b/mindone/transformers/tokenization_utils_base.py index 46697b8adf..f856b56744 100644 --- a/mindone/transformers/tokenization_utils_base.py +++ b/mindone/transformers/tokenization_utils_base.py @@ -1623,7 +1623,9 @@ def apply_chat_template( if continue_final_message: if add_generation_prompt: raise ValueError( - "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + "continue_final_message and add_generation_prompt are not compatible. \ + Use continue_final_message when you want the model to continue the final message, \ + and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." 
) if return_assistant_tokens_mask: raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") @@ -2164,7 +2166,7 @@ def _from_pretrained( init_kwargs["__slow_tokenizer"] = slow_tokenizer init_kwargs["name_or_path"] = pretrained_model_name_or_path - #### Handle tokenizer serialization of added and special tokens + # Handle tokenizer serialization of added and special tokens added_tokens_decoder: dict[int, AddedToken] = {} added_tokens_map: dict[str, AddedToken] = {} # if we have info on the slow added tokens @@ -2927,7 +2929,7 @@ def _is_valid_text_input(t): ) if text_pair is not None and len(text) != len(text_pair): raise ValueError( - f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: " f" {len(text_pair)}." ) batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text @@ -3116,7 +3118,8 @@ def batch_encode_plus( Args: - batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`): + batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, \ + `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`): Batch of sequences or pair of sequences to be encoded. This can be a list of string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see details in `encode_plus`). @@ -3645,7 +3648,7 @@ def truncate_sequences( ids = ids[ids_to_move:] pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None else: - raise ValueError(f"invalid truncation strategy:{self.truncation_side}") + raise ValueError(f"invalid truncation strategy: {self.truncation_side}") elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: if len(pair_ids) > num_tokens_to_remove: @@ -3657,7 +3660,7 @@ def truncate_sequences( overflowing_tokens = pair_ids[:window_len] pair_ids = pair_ids[num_tokens_to_remove:] else: - raise ValueError(f"invalid truncation strategy:{self.truncation_side}") + raise ValueError(f"invalid truncation strategy: {self.truncation_side}") else: logger.error( f"We need to remove {num_tokens_to_remove} to truncate the input " @@ -3746,7 +3749,7 @@ def _pad( encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input else: - raise ValueError(f"Invalid padding strategy:{padding_side}") + raise ValueError(f"Invalid padding strategy: {padding_side}") return encoded_inputs diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index ce1dd43a53..57b2cccd26 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -37,7 +37,7 @@ from .import_utils import is_mindspore_available if is_mindspore_available(): - import mindspore as ms # noqa: F401 + import mindspore as ms # noqa: F401 _CAN_RECORD_REGISTRY = {} @@ -139,7 +139,7 @@ def is_numpy_array(x): def _is_mindspore(x): - import mindspore + import mindspore as ms return isinstance(x, ms.Tensor) @@ -753,7 +753,7 @@ def is_timm_local_checkpoint(pretrained_model_path: str) -> bool: return False -def set_attribute_for_modules(module: "mindspore.nn.Cell", key: str, value: 
Any): +def set_attribute_for_modules(module: "ms.nn.Cell", key: str, value: Any): """ Set a value to a module and all submodules. """ @@ -762,7 +762,7 @@ def set_attribute_for_modules(module: "mindspore.nn.Cell", key: str, value: Any) set_attribute_for_modules(submodule, key, value) -def del_attribute_from_modules(module: "mindspore.nn.Cell", key: str): +def del_attribute_from_modules(module: "ms.nn.Cell", key: str): """ Delete a value from a module and all submodules. """ @@ -808,7 +808,7 @@ class OutputRecorder: layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn". """ - target_class: "type[mindspore.nn.Cell]" + target_class: "type[ms.nn.Cell]" index: Optional[int] = 0 layer_name: Optional[str] = None diff --git a/mindone/transformers/video_utils.py b/mindone/transformers/video_utils.py index 407bb10e04..f8735c1ad2 100644 --- a/mindone/transformers/video_utils.py +++ b/mindone/transformers/video_utils.py @@ -53,12 +53,12 @@ VideoInput = Union[ list["PIL.Image.Image"], "np.ndarray", - "mindspore.Tensor", + "ms.Tensor", list["np.ndarray"], - list["mindspore.Tensor"], + list["ms.Tensor"], list[list["PIL.Image.Image"]], list[list["np.ndarrray"]], - list[list["mindspore.Tensor"]], + list[list["ms.Tensor"]], ] # noqa @@ -113,7 +113,7 @@ def is_scaled_video(video: np.ndarray) -> bool: return np.min(video) >= 0 and np.max(video) <= 1 -def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union["np.ndarray", "mindspore.Tensor"]]: +def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union["np.ndarray", "ms.Tensor"]]: """ Given a batch of videos, converts each video to a 4D array. If video is already in array type, it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element. @@ -134,7 +134,7 @@ def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union["np.ndar return video_converted -def make_batched_videos(videos) -> list[Union["np.ndarray", "mindspore.Tensor"]]: +def make_batched_videos(videos) -> list[Union["np.ndarray", "ms.Tensor"]]: """ Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1. If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image` @@ -681,8 +681,8 @@ def _expand_for_data_format(values): def group_videos_by_shape( - videos: list["mindspore.Tensor"], -) -> tuple[dict[tuple[int, int], list["mindspore.Tensor"]], dict[int, tuple[tuple[int, int], int]]]: + videos: list["ms.Tensor"], +) -> tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[int, tuple[tuple[int, int], int]]]: """ Groups videos by shape. Returns a dictionary with the shape as key and a list of videos with that shape as value, @@ -704,8 +704,8 @@ def group_videos_by_shape( def reorder_videos( - processed_videos: dict[tuple[int, int], "mindspore.Tensor"], grouped_videos_index: dict[int, tuple[int, int]] -) -> list["mindspore.Tensor"]: + processed_videos: dict[tuple[int, int], "ms.Tensor"], grouped_videos_index: dict[int, tuple[int, int]] +) -> list["ms.Tensor"]: """ Reconstructs a list of videos in the original order. 
""" From 7bed7a1080825046e02f853411ff06ba802292ec Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:10:38 +0800 Subject: [PATCH 40/94] add modeling_layers.py from cui yushi --- mindone/transformers/modeling_layers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mindone/transformers/modeling_layers.py b/mindone/transformers/modeling_layers.py index 1f23264b2b..5eb6d01dcf 100644 --- a/mindone/transformers/modeling_layers.py +++ b/mindone/transformers/modeling_layers.py @@ -17,6 +17,8 @@ from abc import ABC from typing import Optional +from transformers.utils import auto_docstring, can_return_tuple + import mindspore as ms import mindspore.nn as nn from mindspore import mint @@ -31,7 +33,6 @@ from .models.auto import AutoModel from .processing_utils import Unpack from .utils import TransformersKwargs, logging -from transformers.utils import auto_docstring, can_return_tuple logger = logging.get_logger(__name__) @@ -259,4 +260,4 @@ def construct( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) + ) \ No newline at end of file From 61b4f5c036a4a075af5d1c96b54a1d17207d0a4c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:18:17 +0800 Subject: [PATCH 41/94] fix import in transformers --- mindone/transformers/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 70a06c1910..c20f152172 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -26,11 +26,16 @@ from packaging import version from .cache_utils import * +from .feature_extraction_sequence_utils import SequenceFeatureExtractor + # Feature Extractor from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .image_processing_base import ImageProcessingMixin from .image_processing_utils import BaseImageProcessor +from .image_processing_utils_fast import BaseImageProcessorFast from .image_utils import ImageFeatureExtractionMixin +from .masking_utils import AttentionMaskInterface +from .modeling_layers import GradientCheckpointingLayer from .modeling_utils import MSPreTrainedModel from .models.albert import ( AlbertForMaskedLM, From 3e3f452a54a740cf4fbd5b2b00f0c5b230da706b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:45:04 +0800 Subject: [PATCH 42/94] rm tokenization_utils.py and tokenization_utils_base.py --- .../models/llava/processing_llava.py | 2 +- mindone/transformers/tokenization_utils.py | 1136 ----- .../transformers/tokenization_utils_base.py | 4135 ----------------- 3 files changed, 1 insertion(+), 5272 deletions(-) delete mode 100644 mindone/transformers/tokenization_utils.py delete mode 100644 mindone/transformers/tokenization_utils_base.py diff --git a/mindone/transformers/models/llava/processing_llava.py b/mindone/transformers/models/llava/processing_llava.py index a6b29f75ac..974eebacd4 100644 --- a/mindone/transformers/models/llava/processing_llava.py +++ b/mindone/transformers/models/llava/processing_llava.py @@ -24,7 +24,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.tokenization_utils_base import PaddingStrategy, 
PreTokenizedInput, TextInput, TruncationStrategy from ...utils import TensorType diff --git a/mindone/transformers/tokenization_utils.py b/mindone/transformers/tokenization_utils.py deleted file mode 100644 index a3af3bc637..0000000000 --- a/mindone/transformers/tokenization_utils.py +++ /dev/null @@ -1,1136 +0,0 @@ -# Copyright 2020 The HuggingFace Inc. team. -# -# This code is adapted from https://github.com/huggingface/transformers -# with modifications to run transformers on mindspore. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see -tokenization_utils_fast.py -""" - -import bisect -import itertools -import re -import unicodedata -from collections import OrderedDict -from typing import Any, Optional, Union, overload - -from .tokenization_utils_base import ( - ENCODE_KWARGS_DOCSTRING, - ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, - INIT_TOKENIZER_DOCSTRING, - AddedToken, - BatchEncoding, - EncodedInput, - EncodedInputPair, - PreTokenizedInput, - PreTokenizedInputPair, - PreTrainedTokenizerBase, - TextInput, - TextInputPair, - TruncationStrategy, -) -from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging - -logger = logging.get_logger(__name__) - -# Slow tokenizers are saved in a vocabulary plus three separated files -SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" -ADDED_TOKENS_FILE = "added_tokens.json" -TOKENIZER_CONFIG_FILE = "tokenizer_config.json" - - -class Trie: - """ - Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass - Loose reference https://en.wikipedia.org/wiki/Trie - """ - - def __init__(self, *args): - self.data = {} - self._tokens = set() - self._termination_char = "" - self.update(*args) - - def update(self, *args): - """ - Updates the Trie with new tokens provided as arguments. - - Args: - *args: Variable number of words to be added to the Trie. - """ - for token in tuple(*args): - self.add(token) - - def add(self, word: str): - """ - Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation. - The special key `""` in `self._termination_char` is used to represent termination. - - This function is idempotent, adding twice the same word will leave the trie unchanged - - Example: - - ```python - >>> trie = Trie() - >>> trie.add("Hello 友達") - >>> trie.data - {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}} - - >>> trie.add("Hello") - >>> trie.data - {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}} - ``` - """ - if not word: - # Prevent empty string - return - - self._tokens.add(word) - ref = self.data - for char in word: - ref[char] = ref.setdefault(char, {}) - ref = ref[char] - ref[self._termination_char] = 1 - - def split(self, text: str) -> list[str]: - """ - Will look for the words added to the trie within `text`. 
Output is the original string splitted along the - boundaries of the words found. - - This trie will match the longest possible word first ! - - Example: - - ```python - >>> trie = Trie() - >>> trie.split("[CLS] This is a extra_id_100") - ["[CLS] This is a extra_id_100"] - - >>> trie.add("[CLS]") - >>> trie.add("extra_id_1") - >>> trie.add("extra_id_100") - >>> trie.split("[CLS] This is a extra_id_100") - ["[CLS]", " This is a ", "extra_id_100"] - ``` - """ - # indexes are counted left of the chars index. - # "hello", index 0, is left of h, index 1 is between h and e. - # index 5 is right of the "o". - - # States are going to capture every possible start (indexes as above) - # as keys, and have as values, a pointer to the position in the trie - # where we're at. This is a partial match for now. - # This enables to keep track of multiple matches while we're iterating - # the string - # If the trie contains, "blowing", and "lower" and we encounter the - # string "blower", we need to split into ["b", "lower"]. - # This is where we need to keep track of multiple possible starts. - states = OrderedDict() - - # This will contain every indices where we need - # to cut. - # We force to cut at offset 0 and len(text) (added later) - offsets = [0] - - # This is used by the lookahead which needs to skip over - # some text where the full match exceeded the place in the initial - # for loop - skip = 0 - # Main loop, Giving this algorithm O(n) complexity - for current, current_char in enumerate(text): - if skip and current < skip: - # Prevents the lookahead for matching twice - # like extra_id_100 and id_100 - continue - - # This will track every state - # that stop matching, we need to stop tracking them. - # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then - # fail on "b", we need to remove 0 from the valid states. - to_remove = set() - # Whenever we found a match, we need to drop everything - # this is a greedy algorithm, it will match on the first found token - reset = False - - # In this case, we already have partial matches (But unfinished) - for start, trie_pointer in states.items(): - if "" in trie_pointer: - # This is a final match, we need to reset and - # store the results in `offsets`. 
- - # Lookahead to match longest first - # Important in case of extra_id_1 vs extra_id_100 - # Here we are also actively looking for other earlier partial - # matches - # "[CLS]", "L", we need to match CLS even if L is special - for lookstart, looktrie_pointer in states.items(): - if lookstart > start: - # This partial match is later, we can stop looking - break - elif lookstart < start: - # This partial match is earlier, the trie pointer - # was already updated, so index is + 1 - lookahead_index = current + 1 - end = current + 1 - else: - # Here lookstart == start and - # looktrie_pointer == trie_pointer - # It wasn't updated yet so indices are current ones - lookahead_index = current - end = current - next_char = text[lookahead_index] if lookahead_index < len(text) else None - if "" in looktrie_pointer: - start = lookstart - end = lookahead_index - skip = lookahead_index - - while next_char in looktrie_pointer: - looktrie_pointer = looktrie_pointer[next_char] - lookahead_index += 1 - if "" in looktrie_pointer: - start = lookstart - end = lookahead_index - skip = lookahead_index - - if lookahead_index == len(text): - # End of string - break - next_char = text[lookahead_index] - # End lookahead - - # Storing and resetting - offsets.append(start) - offsets.append(end) - reset = True - break - elif current_char in trie_pointer: - # The current character being looked at has a match within the trie - # update the pointer (it will be stored back into states later). - trie_pointer = trie_pointer[current_char] - - # Storing back the new pointer into the states. - # Partial matches got longer by one. - states[start] = trie_pointer - else: - # The new character has not match in the trie, we need - # to stop keeping track of this partial match. - # We can't do it directly within the loop because of how - # python iteration works - to_remove.add(start) - - # Either clearing the full start (we found a real match) - # Or clearing only the partial matches that didn't work. - if reset: - states = {} - else: - for start in to_remove: - del states[start] - - # If this character is a starting character within the trie - # start keeping track of this partial match. - if current >= skip and current_char in self.data: - states[current] = self.data[current_char] - - # We have a cut at the end with states. - for start, trie_pointer in states.items(): - if "" in trie_pointer: - # This is a final match, we need to reset and - # store the results in `offsets`. - end = len(text) - offsets.append(start) - offsets.append(end) - # Longest cut is always the one with lower start so the first - # item so we need to break. - break - - return self.cut_text(text, offsets) - - def cut_text(self, text, offsets): - # We have all the offsets now, we just need to do the actual splitting. - # We need to eventually add the first part of the string and the eventual - # last part. - offsets.append(len(text)) - tokens = [] - start = 0 - for end in offsets: - if start > end: - logger.error( - "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it" - " anyway." - ) - continue - elif start == end: - # This might happen if there's a match at index 0 - # we're also preventing zero-width cuts in case of two - # consecutive matches - continue - tokens.append(text[start:end]) - start = end - - return tokens - - -class ExtensionsTrie(Trie): - def __init__(self, *args): - super().__init__(*args) - - def extensions(self, prefix: str): - """ - Generates all extensions of a given prefix token in the Trie. 
- - Example: - - ```python - >>> trie = Trie() - >>> trie.add("apple") - >>> trie.add("app") - >>> trie.add("application") - >>> trie.extensions("app") - ['app', 'apple', 'application'] - ``` - """ - prefix_node = self._get_node(prefix) - ret = self._collect_tokens(prefix_node) - return [prefix + token for token in ret] - - def _get_node(self, token: str) -> dict: - """ - Retrieves the node corresponding to the given token in the Trie. - - Args: - token (str): The token for which the corresponding node needs to be retrieved. - - Returns: - dict: The node in the Trie corresponding to the given token. - """ - node = self.data - for char in token: - if char not in node: - break - - node = node[char] - return node - - def _collect_tokens(self, node: dict) -> list: - """ - Generates all tokens in the Trie starting from a given node. - - Args: - node (dict): The node in the Trie from which tokens need to be generated. - - Returns: - list: List of tokens generated from the given node. - """ - tokens = [self._termination_char] if self._termination_char in node else [] - for token, subtrie_head in node.items(): - if token != self._termination_char: - subtokens = self._collect_tokens(subtrie_head) - tokens.extend([token + subtoken for subtoken in subtokens]) - return tokens - - -def _is_whitespace(char): - """Checks whether `char` is a whitespace character.""" - # \t, \n, and \r are technically control characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `char` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `char` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -def _is_end_of_word(text): - """Checks whether the last character in text is one of a punctuation, control or whitespace character.""" - last_char = text[-1] - return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char)) - - -def _is_start_of_word(text): - """Checks whether the first character in text is one of a punctuation, control or whitespace character.""" - first_char = text[0] - return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) - - -def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str): - """ - Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted. 
- """ - insertion_idx = bisect.bisect_left(token_list, new_token) - # Checks if new_token is already in the ordered token_list - if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token: - # new_token is in token_list, don't add - return - else: - token_list.insert(insertion_idx, new_token) - - -@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) -class PreTrainedTokenizer(PreTrainedTokenizerBase): - """ - Base class for all slow tokenizers. - - Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`]. - - Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading - pretrained tokenizers as well as adding tokens to the vocabulary. - - This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the - specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). - """ - - def __init__(self, **kwargs): - # 1. Init the parent class - - self.tokens_trie = Trie() - - # 2. init `_added_tokens_decoder` if child class did not - if not hasattr(self, "_added_tokens_decoder"): - self._added_tokens_decoder: dict[int, AddedToken] = {} - - # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite - self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {})) - self._added_tokens_encoder: dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()} - - # 4 init the parent class - super().__init__(**kwargs) - - # 4. If some of the special tokens are not part of the vocab, we add them, at the end. - # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers` - self._add_tokens( - [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder], - special_tokens=True, - ) - - self._decode_use_source_tokenizer = False - - @property - def is_fast(self) -> bool: - return False - - @property - def vocab_size(self) -> int: - """ - `int`: Size of the base vocabulary (without the added tokens). - """ - raise NotImplementedError - - @property - def added_tokens_encoder(self) -> dict[str, int]: - """ - Returns the sorted mapping from string to index. The added tokens encoder is cached for performance - optimisation in `self._added_tokens_encoder` for the slow tokenizers. - """ - return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])} - - @property - def added_tokens_decoder(self) -> dict[int, AddedToken]: - """ - Returns the added tokens in the vocabulary as a dictionary of index to AddedToken. - - Returns: - `dict[str, int]`: The added tokens. 
- """ - return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])) - - @added_tokens_decoder.setter - def added_tokens_decoder(self, value: dict[int, Union[AddedToken, str]]) -> dict[int, AddedToken]: - # Always raise an error if string because users should define the behavior - for index, token in value.items(): - if not isinstance(token, (str, AddedToken)) or not isinstance(index, int): - raise TypeError( - f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, \ - should be a dict of {int, Union[AddedToken, str]}" - ) - - self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token - self._added_tokens_encoder[str(token)] = index - self._update_total_vocab_size() - - def get_added_vocab(self) -> dict[str, int]: - """ - Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from - the fast call because for now we always add the tokens even if they are already in the vocabulary. This is - something we should change. - - Returns: - `dict[str, int]`: The added tokens. - """ - return self._added_tokens_encoder - - def __len__(self): - """ - Size of the full vocabulary with the added tokens. - """ - return self.total_vocab_size - - def _update_total_vocab_size(self): - """ - Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because - otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and - is only updated when adding tokens. - """ - self.total_vocab_size = len(self.get_vocab()) - - def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int: - """ - Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to - it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the - vocab which is why they have to be handled specifically. - - Args: - new_tokens (`list[str]`or `list[tokenizers.AddedToken]`): - Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary - (tested by checking if the tokenizer assign the index of the `unk_token` to them). If a token is part - of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the - stripping and normalization of this token. This is NOT possible in `tokenizers`. - special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the tokens should be added as special tokens. - - Returns: - `int`: The number of tokens actually added to the vocabulary. - - Examples: - - ```python - # Let's see how to increase the vocabulary of Bert model and tokenizer - tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") - model = BertModel.from_pretrained("google-bert/bert-base-uncased") - - num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"]) - print("We have added", num_added_toks, "tokens") - # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer. - model.resize_token_embeddings(len(tokenizer)) - ```""" - added_tokens = 0 - if new_tokens is None: - return added_tokens - # TODO this is fairly slow to improve! 
- current_vocab = self.get_vocab().copy() - new_idx = len(current_vocab) # only call this once, len gives the last index + 1 - for token in new_tokens: - if not isinstance(token, (str, AddedToken)): - raise TypeError(f"Token {token} is not a string but a {type(token)}.") - if str(token) == "": - continue - if isinstance(token, str): - if token in self._added_tokens_encoder: - continue - else: - # very important for fast and slow equivalence! - is_special = token in self.all_special_tokens or special_tokens - token = AddedToken(token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special) - elif special_tokens: - # doing token.special=True changes the normalization! will fix in rust - # this is important and the only reason why the AddedTokens in each class are normalized by default - token.__setstate__({"special": True, "normalized": token.normalized}) - if token in self._added_tokens_decoder: - continue - if not token.special and token.normalized and getattr(self, "do_lower_case", False): - # Normalize if requested - token.content = token.content.lower() - if token.content not in current_vocab: - token_index = new_idx + added_tokens - current_vocab[token.content] = token_index - added_tokens += 1 - else: - token_index = current_vocab[token.content] - - if token.special and str(token) not in self.all_special_tokens: - self._additional_special_tokens.append(token) - # the setter automatically updates the reverse map - self._added_tokens_decoder[token_index] = token - self._added_tokens_encoder[token.content] = token_index - if self.verbose: - logger.info(f"Adding {token} to the vocabulary") - - self._update_trie() - self._update_total_vocab_size() - return added_tokens - - def _update_trie(self, unique_no_split_tokens: Optional[str] = []): - for token in self._added_tokens_decoder.values(): - if token not in self.tokens_trie._tokens: - self.tokens_trie.add(token.content) - for token in unique_no_split_tokens: - if token not in self.tokens_trie._tokens: - self.tokens_trie.add(token) - - def num_special_tokens_to_add(self, pair: bool = False) -> int: - """ - Returns the number of added tokens when encoding a sequence with special tokens. - - - - This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put - this inside your training loop. - - - - Args: - pair (`bool`, *optional*, defaults to `False`): - Whether the number of added tokens should be computed in the case of a sequence pair or a single - sequence. - - Returns: - `int`: Number of special tokens added to sequences. - """ - token_ids_0 = [] - token_ids_1 = [] - return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) - - def tokenize(self, text: TextInput, **kwargs) -> list[str]: - """ - Converts a string into a sequence of tokens, using the tokenizer. - - Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies - (BPE/SentencePieces/WordPieces). Takes care of added tokens. - - Args: - text (`str`): - The sequence to be encoded. - **kwargs (additional keyword arguments): - Passed along to the model-specific `prepare_for_tokenization` preprocessing method. - - Returns: - `list[str]`: The list of tokens. 
- """ - split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens) - - text, kwargs = self.prepare_for_tokenization(text, **kwargs) - - if kwargs: - logger.warning(f"Keyword arguments {kwargs} not recognized.") - - if hasattr(self, "do_lower_case") and self.do_lower_case: - # convert non-special tokens to lowercase. Might be super slow as well? - escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)] - escaped_special_toks += [ - re.escape(s_tok.content) - for s_tok in (self._added_tokens_decoder.values()) - if not s_tok.special and s_tok.normalized - ] - pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" - text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) - - if split_special_tokens: - no_split_token = [] - tokens = [text] - else: - no_split_token = self._added_tokens_encoder.keys() # don't split on any of the added tokens - # "This is something else" - tokens = self.tokens_trie.split(text) - - # ["This is something", "", " else"] - for i, token in enumerate(tokens): - if token in no_split_token: - tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None) - left = tokens[i - 1] if i > 0 else None - right = tokens[i + 1] if i < len(tokens) - 1 else None - if isinstance(tok_extended, AddedToken): - if tok_extended.rstrip and right: - # A bit counter-intuitive but we strip the left of the string - # since tok_extended.rstrip means the special token is eating all white spaces on its right - tokens[i + 1] = right.lstrip() - # Strip white spaces on the left - if tok_extended.lstrip and left: - tokens[i - 1] = left.rstrip() # Opposite here - if tok_extended.single_word and left and left[-1] != " ": - tokens[i - 1] += token - tokens[i] = "" - elif tok_extended.single_word and right and right[0] != " ": - tokens[i + 1] = token + tokens[i + 1] - tokens[i] = "" - else: - raise ValueError( - f"{tok_extended} cannot be tokenized because it was not properly added" - f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}" - ) - # ["This is something", "", "else"] - tokenized_text = [] - for token in tokens: - # Need to skip eventual empty (fully stripped) tokens - if not token: - continue - if token in no_split_token: - tokenized_text.append(token) - else: - tokenized_text.extend(self._tokenize(token)) - # ["This", " is", " something", "", "else"] - return tokenized_text - - def _tokenize(self, text, **kwargs): - """ - Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based - vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). - - Do NOT take care of added tokens. - """ - raise NotImplementedError - - def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]: - """ - Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the - vocabulary. - - Args: - tokens (`str` or `list[str]`): One or several token(s) to convert to token id(s). - - Returns: - `int` or `list[int]`: The token id or list of token ids. 
- """ - if tokens is None: - return None - - if isinstance(tokens, str): - return self._convert_token_to_id_with_added_voc(tokens) - - ids = [] - for token in tokens: - ids.append(self._convert_token_to_id_with_added_voc(token)) - return ids - - def _convert_token_to_id_with_added_voc(self, token): - if token is None: - return None - - if token in self._added_tokens_encoder: - return self._added_tokens_encoder[token] - return self._convert_token_to_id(token) - - def _convert_token_to_id(self, token): - raise NotImplementedError - - def _encode_plus( - self, - text: Union[TextInput, PreTokenizedInput, EncodedInput], - text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs, - ) -> BatchEncoding: - def get_input_ids(text): - if isinstance(text, str): - tokens = self.tokenize(text, **kwargs) - return self.convert_tokens_to_ids(tokens) - elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): - if is_split_into_words: - tokens = list( - itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) - ) - return self.convert_tokens_to_ids(tokens) - else: - return self.convert_tokens_to_ids(text) - elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): - return text - else: - if is_split_into_words: - raise ValueError( - f"Input {text} is not valid. Should be a string or a list/tuple of strings when" - " `is_split_into_words=True`." - ) - else: - raise ValueError( - f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" - " integers." - ) - - if return_offsets_mapping: - raise NotImplementedError( - "return_offset_mapping is not available when using Python tokenizers. " - "To use this feature, change your tokenizer to one deriving from " - "transformers.PreTrainedTokenizerFast. 
" - "More information on available tokenizers at " - "https://github.com/huggingface/transformers/pull/2674" - ) - - first_ids = get_input_ids(text) - second_ids = get_input_ids(text_pair) if text_pair is not None else None - - return self.prepare_for_model( - first_ids, - pair_ids=second_ids, - add_special_tokens=add_special_tokens, - padding=padding_strategy.value, - truncation=truncation_strategy.value, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_tensors=return_tensors, - prepend_batch_axis=True, - return_attention_mask=return_attention_mask, - return_token_type_ids=return_token_type_ids, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_length=return_length, - verbose=verbose, - ) - - def _batch_encode_plus( - self, - batch_text_or_text_pairs: Union[ - list[TextInput], - list[TextInputPair], - list[PreTokenizedInput], - list[PreTokenizedInputPair], - list[EncodedInput], - list[EncodedInputPair], - ], - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - split_special_tokens: bool = False, - **kwargs, - ) -> BatchEncoding: - def get_input_ids(text): - if isinstance(text, str): - tokens = self.tokenize(text, **kwargs) - return self.convert_tokens_to_ids(tokens) - elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): - if is_split_into_words: - tokens = list( - itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) - ) - return self.convert_tokens_to_ids(tokens) - else: - return self.convert_tokens_to_ids(text) - elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): - return text - else: - raise ValueError( - "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." - ) - - if return_offsets_mapping: - raise NotImplementedError( - "return_offset_mapping is not available when using Python tokenizers. " - "To use this feature, change your tokenizer to one deriving from " - "transformers.PreTrainedTokenizerFast." 
- ) - - input_ids = [] - for ids_or_pair_ids in batch_text_or_text_pairs: - if not isinstance(ids_or_pair_ids, (list, tuple)): - ids, pair_ids = ids_or_pair_ids, None - elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)): - ids, pair_ids = ids_or_pair_ids, None - else: - ids, pair_ids = ids_or_pair_ids - - first_ids = get_input_ids(ids) - second_ids = get_input_ids(pair_ids) if pair_ids is not None else None - input_ids.append((first_ids, second_ids)) - - batch_outputs = self._batch_prepare_for_model( - input_ids, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_attention_mask=return_attention_mask, - return_token_type_ids=return_token_type_ids, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_length=return_length, - return_tensors=return_tensors, - verbose=verbose, - split_special_tokens=split_special_tokens, - ) - - return BatchEncoding(batch_outputs) - - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def _batch_prepare_for_model( - self, - batch_ids_pairs: list[Union[PreTokenizedInputPair, tuple[list[int], None]]], - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[str] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_length: bool = False, - verbose: bool = True, - split_special_tokens: bool = False, - ) -> BatchEncoding: - """ - Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 
It - adds special tokens, truncates sequences if overflowing while taking into account the special tokens and - manages a moving window (with user defined stride) for overflowing tokens - - Args: - batch_ids_pairs: list of tokenized input ids or input ids pairs - """ - - batch_outputs = {} - for first_ids, second_ids in batch_ids_pairs: - outputs = self.prepare_for_model( - first_ids, - second_ids, - add_special_tokens=add_special_tokens, - padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward - truncation=truncation_strategy.value, - max_length=max_length, - stride=stride, - pad_to_multiple_of=None, # we pad in batch afterward - padding_side=None, # we pad in batch afterward - return_attention_mask=False, # we pad in batch afterward - return_token_type_ids=return_token_type_ids, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_length=return_length, - return_tensors=None, # We convert the whole batch to tensors at the end - prepend_batch_axis=False, - verbose=verbose, - split_special_tokens=split_special_tokens, - ) - - for key, value in outputs.items(): - if key not in batch_outputs: - batch_outputs[key] = [] - batch_outputs[key].append(value) - - batch_outputs = self.pad( - batch_outputs, - padding=padding_strategy.value, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_attention_mask=return_attention_mask, - ) - - batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) - - return batch_outputs - - def prepare_for_tokenization( - self, text: str, is_split_into_words: bool = False, **kwargs - ) -> tuple[str, dict[str, Any]]: - """ - Performs any necessary transformations before tokenization. - - This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the - `kwargs` at the end of the encoding process to be sure all the arguments have been used. - - Args: - text (`str`): - The text to prepare. - is_split_into_words (`bool`, *optional*, defaults to `False`): - Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the - tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) - which it will tokenize. This is useful for NER or token classification. - kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to use for the tokenization. - - Returns: - `tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs. - """ - return (text, kwargs) - - def get_special_tokens_mask( - self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False - ) -> list[int]: - """ - Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. - - Args: - token_ids_0 (`list[int]`): - List of ids of the first sequence. - token_ids_1 (`list[int]`, *optional*): - List of ids of the second sequence. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
- """ - if already_has_special_tokens: - if token_ids_1 is not None: - raise ValueError( - "You should not supply a second sequence if the provided sequence of " - "ids is already formatted with special tokens for the model." - ) - - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) - - @overload - def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: - ... - - @overload - def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: - ... - - def convert_ids_to_tokens( - self, ids: Union[int, list[int]], skip_special_tokens: bool = False - ) -> Union[str, list[str]]: - """ - Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and - added tokens. - - Args: - ids (`int` or `list[int]`): - The token id (or token ids) to convert to tokens. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - - Returns: - `str` or `list[str]`: The decoded token(s). - """ - if isinstance(ids, int): - if ids in self._added_tokens_decoder: - return self._added_tokens_decoder[ids].content - else: - return self._convert_id_to_token(ids) - tokens = [] - for index in ids: - index = int(index) - if skip_special_tokens and index in self.all_special_ids: - continue - if index in self._added_tokens_decoder: - tokens.append(self._added_tokens_decoder[index].content) - else: - tokens.append(self._convert_id_to_token(index)) - return tokens - - def _convert_id_to_token(self, index: int) -> str: - raise NotImplementedError - - def convert_tokens_to_string(self, tokens: list[str]) -> str: - return " ".join(tokens) - - def _decode( - self, - token_ids: Union[int, list[int]], - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - spaces_between_special_tokens: bool = True, - **kwargs, - ) -> str: - self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) - - filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) - # If given is a single id, prevents splitting the string in upcoming loop - if isinstance(filtered_tokens, str): - filtered_tokens = [filtered_tokens] - - legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | { - token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size - } - # To avoid mixing byte-level and unicode for byte-level BPT - # we need to build string separately for added tokens and byte-level tokens - # cf. 
https://github.com/huggingface/transformers/issues/1133 - sub_texts = [] - current_sub_text = [] - # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string - for token in filtered_tokens: - if skip_special_tokens and token in self.all_special_tokens: - continue - if token in legacy_added_tokens: - if current_sub_text: - string = self.convert_tokens_to_string(current_sub_text) - if len(string) > 0: - sub_texts.append(string) - current_sub_text = [] - sub_texts.append(token) - else: - current_sub_text.append(token) - if current_sub_text: - sub_texts.append(self.convert_tokens_to_string(current_sub_text)) - - if spaces_between_special_tokens: - text = " ".join(sub_texts) - else: - text = "".join(sub_texts) - - clean_up_tokenization_spaces = ( - clean_up_tokenization_spaces - if clean_up_tokenization_spaces is not None - else self.clean_up_tokenization_spaces - ) - if clean_up_tokenization_spaces: - clean_text = self.clean_up_tokenization(text) - return clean_text - else: - return text diff --git a/mindone/transformers/tokenization_utils_base.py b/mindone/transformers/tokenization_utils_base.py deleted file mode 100644 index f856b56744..0000000000 --- a/mindone/transformers/tokenization_utils_base.py +++ /dev/null @@ -1,4135 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# This code is adapted from https://github.com/huggingface/transformers -# with modifications to run transformers on mindspore. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (host all the user -fronting encoding methods) Special token mixing (host the special tokens logic) and BatchEncoding (wrap the dictionary -of output with special method for the Fast tokenizers) -""" - -import copy -import json -import os -import re -import warnings -from collections import UserDict -from collections.abc import Mapping, Sequence, Sized -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union - -import numpy as np -from packaging import version -from transformers.dynamic_module_utils import custom_object_save -from transformers.utils import ( - CHAT_TEMPLATE_DIR, - CHAT_TEMPLATE_FILE, - PushToHubMixin, - add_end_docstrings, - cached_file, - copy_func, - download_url, - extract_commit_hash, - is_flax_available, - is_jax_tensor, - is_mlx_available, - is_offline_mode, - is_protobuf_available, - is_remote_url, - is_tf_available, - is_tf_tensor, - is_tokenizers_available, - is_torch_available, - is_torch_device, - is_torch_tensor, - list_repo_templates, - logging, -) -from transformers.utils.chat_template_utils import render_jinja_template -from transformers.utils.import_utils import PROTOBUF_IMPORT_ERROR - -from . 
import __version__ -from .utils import ExplicitEnum, PaddingStrategy, TensorType, is_numpy_array, requires_backends, to_py_obj - -if TYPE_CHECKING: - if is_torch_available(): - import torch - if is_tf_available(): - import tensorflow as tf - if is_flax_available(): - import jax.numpy as jnp # noqa: F401 - - -def import_protobuf_decode_error(error_message=""): - if is_protobuf_available(): - from google.protobuf.message import DecodeError - - return DecodeError - else: - raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message)) - - -if is_tokenizers_available(): - from tokenizers import AddedToken - from tokenizers import Encoding as EncodingFast -else: - - @dataclass(frozen=False, eq=True) - class AddedToken: - """ - AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the - way it should behave. - - The `normalized` will default to `not special` if it is not specified, similarly to the definition in - `tokenizers`. - """ - - def __init__(self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None): - self.content = content - self.single_word = single_word - self.lstrip = lstrip - self.rstrip = rstrip - self.special = special - self.normalized = normalized if normalized is not None else not special - - def __getstate__(self): - return self.__dict__ - - def __str__(self): - return self.content - - @dataclass - class EncodingFast: - """This is dummy class because without the `tokenizers` library we don't have these objects anyway""" - - pass - - -logger = logging.get_logger(__name__) - -VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input -LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER - -# Define type aliases and NamedTuples -TextInput = str -PreTokenizedInput = list[str] -EncodedInput = list[int] -TextInputPair = tuple[str, str] -PreTokenizedInputPair = tuple[list[str], list[str]] -EncodedInputPair = tuple[list[int], list[int]] - -# Define type aliases for text-related non-text modalities -AudioInput = Union["np.ndarray", "torch.Tensor", list["np.ndarray"], list["torch.Tensor"]] - -# Slow tokenizers used to be saved in three separated files -SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" -ADDED_TOKENS_FILE = "added_tokens.json" -TOKENIZER_CONFIG_FILE = "tokenizer_config.json" - -# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file -FULL_TOKENIZER_FILE = "tokenizer.json" -_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json") - - -class TruncationStrategy(ExplicitEnum): - """ - Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in - an IDE. - """ - - ONLY_FIRST = "only_first" - ONLY_SECOND = "only_second" - LONGEST_FIRST = "longest_first" - DO_NOT_TRUNCATE = "do_not_truncate" - - -class CharSpan(NamedTuple): - """ - Character span in the original string. - - Args: - start (`int`): Index of the first character in the original string. - end (`int`): Index of the character following the last character in the original string. - """ - - start: int - end: int - - -class TokenSpan(NamedTuple): - """ - Token span in an encoded string (list of tokens). - - Args: - start (`int`): Index of the first token in the span. - end (`int`): Index of the token following the last token in the span. 
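For reference, this is how the `AddedToken` options described above are typically used; the snippet assumes the `tokenizers` package is installed (when it is absent, the lightweight fallback dataclass defined above provides the same fields):

```python
from tokenizers import AddedToken  # the dataclass above is a drop-in stand-in without `tokenizers`

tok = AddedToken("<extra_id_0>", lstrip=True, rstrip=False, special=True)
print(str(tok))      # "<extra_id_0>" -- __str__ returns the raw content
print(tok.special)   # True
# Per the docstring above, `normalized` defaults to `not special` when left unspecified,
# so a special token is matched without normalization being applied first.
print(tok.normalized)
```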
- """ - - start: int - end: int - - -class BatchEncoding(UserDict): - """ - Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`], - [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and - [`~tokenization_utils_base.PreTrainedTokenizerBase.batch_encode_plus`] methods (tokens, attention_masks, etc). - - This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes - utility methods to map from word/character space to token space. - - Args: - data (`dict`, *optional*): - Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods - ('input_ids', 'attention_mask', etc.). - encoding (`tokenizers.Encoding` or `Sequence[tokenizers.Encoding]`, *optional*): - If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character - space to token space the `tokenizers.Encoding` instance or list of instance (for batches) hold this - information. - tensor_type (`Union[None, str, TensorType]`, *optional*): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at - initialization. - prepend_batch_axis (`bool`, *optional*, defaults to `False`): - Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this - parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*. - n_sequences (`Optional[int]`, *optional*): - You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at - initialization. - """ - - def __init__( - self, - data: Optional[dict[str, Any]] = None, - encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, - tensor_type: Union[None, str, TensorType] = None, - prepend_batch_axis: bool = False, - n_sequences: Optional[int] = None, - ): - super().__init__(data) - - if isinstance(encoding, EncodingFast): - encoding = [encoding] - - self._encodings = encoding - - if n_sequences is None and encoding is not None and encoding: - n_sequences = encoding[0].n_sequences - - self._n_sequences = n_sequences - - self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) - - @property - def n_sequences(self) -> Optional[int]: - """ - `Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this - [`BatchEncoding`]. Currently can be one of `None` (unknown), `1` (a single sentence) or `2` (a pair of - sentences) - """ - return self._n_sequences - - @property - def is_fast(self) -> bool: - """ - `bool`: Indicate whether this [`BatchEncoding`] was generated from the result of a [`PreTrainedTokenizerFast`] - or not. - """ - return self._encodings is not None - - def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]: - """ - If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', - etc.). - - If the key is an integer, get the `tokenizers.Encoding` for batch item with index `key`. - - If the key is a slice, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', etc.) - with the constraint of slice. - """ - if isinstance(item, str): - return self.data[item] - elif self._encodings is not None: - return self._encodings[item] - elif isinstance(item, slice): - return {key: self.data[key][item] for key in self.data.keys()} - else: - raise KeyError( - "Invalid key. 
Only three types of key are available: " - "(1) string, (2) integers for backend Encoding, and (3) slices for data subsetting." - ) - - def __getattr__(self, item: str): - try: - return self.data[item] - except KeyError: - raise AttributeError - - def __getstate__(self): - return {"data": self.data, "encodings": self._encodings} - - def __setstate__(self, state): - if "data" in state: - self.data = state["data"] - - if "encodings" in state: - self._encodings = state["encodings"] - - # After this point: - # Extended properties and methods only available for fast (Rust-based) tokenizers - # provided by HuggingFace tokenizers library. - - @property - def encodings(self) -> Optional[list[EncodingFast]]: - """ - `Optional[list[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns `None` if - the input was tokenized through Python (i.e., not a fast) tokenizer. - """ - return self._encodings - - def tokens(self, batch_index: int = 0) -> list[str]: - """ - Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to - integer indices) at a given batch index (only works for the output of a fast tokenizer). - - Args: - batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. - - Returns: - `list[str]`: The list of tokens at that index. - """ - if not self._encodings: - raise ValueError( - "tokens() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" - " class)." - ) - return self._encodings[batch_index].tokens - - def sequence_ids(self, batch_index: int = 0) -> list[Optional[int]]: - """ - Return a list mapping the tokens to the id of their original sentences: - - - `None` for special tokens added around or between sequences, - - `0` for tokens corresponding to words in the first sequence, - - `1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly - encoded. - - Args: - batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. - - Returns: - `list[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added - by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding - sequence. - """ - if not self._encodings: - raise ValueError( - "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" - " class)." - ) - return self._encodings[batch_index].sequence_ids - - def words(self, batch_index: int = 0) -> list[Optional[int]]: - """ - Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. - - Args: - batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. - - Returns: - `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the - tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word - (several tokens will be mapped to the same word index if they are parts of that word). - """ - if not self._encodings: - raise ValueError( - "words() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" - " class)." 
- ) - warnings.warn( - "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " - "but more self-explanatory `BatchEncoding.word_ids()` property.", - FutureWarning, - ) - return self.word_ids(batch_index) - - def word_ids(self, batch_index: int = 0) -> list[Optional[int]]: - """ - Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. - - Args: - batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. - - Returns: - `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the - tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word - (several tokens will be mapped to the same word index if they are parts of that word). - """ - if not self._encodings: - raise ValueError( - "word_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" - " class)." - ) - return self._encodings[batch_index].word_ids - - def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: - """ - Get the index of the sequence represented by the given token. In the general use case, this method returns `0` - for a single sequence or the first sequence of a pair, and `1` for the second sequence of a pair - - Can be called as: - - - `self.token_to_sequence(token_index)` if batch size is 1 - - `self.token_to_sequence(batch_index, token_index)` if batch size is greater than 1 - - This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., - words are defined by the user). In this case it allows to easily associate encoded tokens with provided - tokenized words. - - Args: - batch_or_token_index (`int`): - Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of - the token in the sequence. - token_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the - sequence. - - Returns: - `int`: Index of the word in the input sequence. - """ - - if not self._encodings: - raise ValueError("token_to_sequence() is not available when using Python based tokenizers") - if token_index is not None: - batch_index = batch_or_token_index - else: - batch_index = 0 - token_index = batch_or_token_index - if batch_index < 0: - batch_index = self._batch_size + batch_index - if token_index < 0: - token_index = self._seq_len + token_index - return self._encodings[batch_index].token_to_sequence(token_index) - - def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: - """ - Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch. - - Can be called as: - - - `self.token_to_word(token_index)` if batch size is 1 - - `self.token_to_word(batch_index, token_index)` if batch size is greater than 1 - - This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., - words are defined by the user). In this case it allows to easily associate encoded tokens with provided - tokenized words. - - Args: - batch_or_token_index (`int`): - Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of - the token in the sequence. - token_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the - sequence. 
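A short sketch of `word_ids()` and `sequence_ids()` with a fast tokenizer (the checks above show they raise for Python-based tokenizers); `bert-base-uncased` is only an illustrative checkpoint:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

words = ["HuggingFace", "tokenizers", "rock"]
enc = tokenizer(words, is_split_into_words=True)

print(enc.tokens())        # sub-word tokens, including [CLS] and [SEP]
print(enc.word_ids())      # None for special tokens, otherwise the index of the source word
print(enc.sequence_ids())  # None for special tokens, 0 for tokens of the single input sequence
```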
- - Returns: - `int`: Index of the word in the input sequence. - """ - - if not self._encodings: - raise ValueError("token_to_word() is not available when using Python based tokenizers") - if token_index is not None: - batch_index = batch_or_token_index - else: - batch_index = 0 - token_index = batch_or_token_index - if batch_index < 0: - batch_index = self._batch_size + batch_index - if token_index < 0: - token_index = self._seq_len + token_index - return self._encodings[batch_index].token_to_word(token_index) - - def word_to_tokens( - self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 - ) -> Optional[TokenSpan]: - """ - Get the encoded token span corresponding to a word in a sequence of the batch. - - Token spans are returned as a [`~tokenization_utils_base.TokenSpan`] with: - - - **start** -- Index of the first token. - - **end** -- Index of the token following the last token. - - Can be called as: - - - `self.word_to_tokens(word_index, sequence_index: int = 0)` if batch size is 1 - - `self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)` if batch size is greater or equal to - 1 - - This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words - are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized - words. - - Args: - batch_or_word_index (`int`): - Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of - the word in the sequence. - word_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the - sequence. - sequence_index (`int`, *optional*, defaults to 0): - If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 - or 1) the provided word index belongs to. - - Returns: - ([`~tokenization_utils_base.TokenSpan`], *optional*): Span of tokens in the encoded sequence. Returns - `None` if no tokens correspond to the word. This can happen especially when the token is a special token - that has been used to format the tokenization. For example when we add a class token at the very beginning - of the tokenization. - """ - - if not self._encodings: - raise ValueError("word_to_tokens() is not available when using Python based tokenizers") - if word_index is not None: - batch_index = batch_or_word_index - else: - batch_index = 0 - word_index = batch_or_word_index - if batch_index < 0: - batch_index = self._batch_size + batch_index - if word_index < 0: - word_index = self._seq_len + word_index - span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index) - return TokenSpan(*span) if span is not None else None - - def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> Optional[CharSpan]: - """ - Get the character span corresponding to an encoded token in a sequence of the batch. - - Character spans are returned as a [`~tokenization_utils_base.CharSpan`] with: - - - **start** -- Index of the first character in the original string associated to the token. - - **end** -- Index of the character following the last character in the original string associated to the - token. - - Can be called as: - - - `self.token_to_chars(token_index)` if batch size is 1 - - `self.token_to_chars(batch_index, token_index)` if batch size is greater or equal to 1 - - Args: - batch_or_token_index (`int`): - Index of the sequence in the batch. 
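To make the word/token index mapping above concrete, a sketch with a fast tokenizer and a pre-tokenized sentence pair (checkpoint name illustrative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
enc = tokenizer(["My", "friends"], ["are", "cool"], is_split_into_words=True)

span = enc.word_to_tokens(0)                      # TokenSpan of word 0 in the first sequence
print(span.start, span.end)
print(enc.token_to_word(span.start))              # back to the word index: 0

print(enc.word_to_tokens(0, sequence_index=1))    # word 0 of the second sequence of the pair
```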
If the batch only comprise one sequence, this can be the index of - the token in the sequence. - token_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the token or tokens in - the sequence. - - Returns: - [`~tokenization_utils_base.CharSpan`]: Span of characters in the original string, or None, if the token - (e.g. , ) doesn't correspond to any chars in the origin string. - """ - - if not self._encodings: - raise ValueError("token_to_chars() is not available when using Python based tokenizers") - if token_index is not None: - batch_index = batch_or_token_index - else: - batch_index = 0 - token_index = batch_or_token_index - span_indices = self._encodings[batch_index].token_to_chars(token_index) - - return CharSpan(*span_indices) if span_indices is not None else None - - def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: - """ - Get the index of the token in the encoded output comprising a character in the original string for a sequence - of the batch. - - Can be called as: - - - `self.char_to_token(char_index)` if batch size is 1 - - `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1 - - This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words - are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized - words. - - Args: - batch_or_char_index (`int`): - Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of - the word in the sequence - char_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the - sequence. - sequence_index (`int`, *optional*, defaults to 0): - If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 - or 1) the provided character index belongs to. - - - Returns: - `int`: Index of the token, or None if the char index refers to a whitespace only token and whitespace is - trimmed with `trim_offsets=True`. - """ - - if not self._encodings: - raise ValueError("char_to_token() is not available when using Python based tokenizers") - if char_index is not None: - batch_index = batch_or_char_index - else: - batch_index = 0 - char_index = batch_or_char_index - return self._encodings[batch_index].char_to_token(char_index, sequence_index) - - def word_to_chars( - self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 - ) -> CharSpan: - """ - Get the character span in the original string corresponding to given word in a sequence of the batch. - - Character spans are returned as a CharSpan NamedTuple with: - - - start: index of the first character in the original string - - end: index of the character following the last character in the original string - - Can be called as: - - - `self.word_to_chars(word_index)` if batch size is 1 - - `self.word_to_chars(batch_index, word_index)` if batch size is greater or equal to 1 - - Args: - batch_or_word_index (`int`): - Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of - the word in the sequence - word_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the - sequence. 
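A sketch of the character/token round trip described above (fast tokenizer required; checkpoint name illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
text = "Tokenizers are fast."
enc = tokenizer(text)

tok_idx = enc.char_to_token(0)       # index of the token covering the character at position 0
span = enc.token_to_chars(tok_idx)   # back to a CharSpan in the original string
print(tok_idx, text[span.start:span.end])
```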
- sequence_index (`int`, *optional*, defaults to 0): - If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 - or 1) the provided word index belongs to. - - Returns: - `CharSpan` or `list[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan - are NamedTuple with: - - - start: index of the first character associated to the token in the original string - - end: index of the character following the last character associated to the token in the original - string - """ - - if not self._encodings: - raise ValueError("word_to_chars() is not available when using Python based tokenizers") - if word_index is not None: - batch_index = batch_or_word_index - else: - batch_index = 0 - word_index = batch_or_word_index - return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index))) - - def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: - """ - Get the word in the original string corresponding to a character in the original string of a sequence of the - batch. - - Can be called as: - - - `self.char_to_word(char_index)` if batch size is 1 - - `self.char_to_word(batch_index, char_index)` if batch size is greater than 1 - - This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words - are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized - words. - - Args: - batch_or_char_index (`int`): - Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of - the character in the original string. - char_index (`int`, *optional*): - If a batch index is provided in *batch_or_token_index*, this can be the index of the character in the - original string. - sequence_index (`int`, *optional*, defaults to 0): - If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 - or 1) the provided character index belongs to. - - - Returns: - `int` or `list[int]`: Index or indices of the associated encoded token(s). - """ - - if not self._encodings: - raise ValueError("char_to_word() is not available when using Python based tokenizers") - if char_index is not None: - batch_index = batch_or_char_index - else: - batch_index = 0 - char_index = batch_or_char_index - return self._encodings[batch_index].char_to_word(char_index, sequence_index) - - def convert_to_tensors( - self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False - ): - """ - Convert the inner content to tensors. - - Args: - tensor_type (`str` or [`~utils.TensorType`], *optional*): - The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If - `None`, no modification is done. - prepend_batch_axis (`int`, *optional*, defaults to `False`): - Whether or not to add the batch dimension during the conversion. 
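Likewise for the word/character mappings above (fast tokenizer only; checkpoint illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
text = "Paris is nice"
enc = tokenizer(text)

span = enc.word_to_chars(0)          # CharSpan of word 0 in the original string
print(text[span.start:span.end])     # "Paris"
print(enc.char_to_word(0))           # character 0 belongs to word 0
```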
- """ - if tensor_type is None: - return self - - # Convert to TensorType - if not isinstance(tensor_type, TensorType): - tensor_type = TensorType(tensor_type) - - # Get a function reference for the correct framework - if tensor_type == TensorType.TENSORFLOW: - if not is_tf_available(): - raise ImportError("Unable to convert output to TensorFlow tensors format, TensorFlow is not installed.") - import tensorflow as tf - - as_tensor = tf.constant - is_tensor = tf.is_tensor - elif tensor_type == TensorType.PYTORCH: - if not is_torch_available(): - raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") - import torch - - is_tensor = torch.is_tensor - - def as_tensor(value, dtype=None): - if isinstance(value, list) and isinstance(value[0], np.ndarray): - return torch.from_numpy(np.array(value)) - return torch.tensor(value) - - elif tensor_type == TensorType.JAX: - if not is_flax_available(): - raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") - import jax.numpy as jnp # noqa: F811 - - as_tensor = jnp.array - is_tensor = is_jax_tensor - - elif tensor_type == TensorType.MLX: - if not is_mlx_available(): - raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.") - import mlx.core as mx - - as_tensor = mx.array - - def is_tensor(obj): - return isinstance(obj, mx.array) - - else: - - def as_tensor(value, dtype=None): - if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)): - value_lens = [len(val) for val in value] - if len(set(value_lens)) > 1 and dtype is None: - # we have a ragged list so handle explicitly - value = as_tensor([np.asarray(val) for val in value], dtype=object) - return np.asarray(value, dtype=dtype) - - is_tensor = is_numpy_array - - # Do the tensor conversion in batch - for key, value in self.items(): - try: - if prepend_batch_axis: - value = [value] - - if not is_tensor(value): - tensor = as_tensor(value) - - # Removing this for now in favor of controlling the shape with `prepend_batch_axis` - # # at-least2d - # if tensor.ndim > 2: - # tensor = tensor.squeeze(0) - # elif tensor.ndim < 2: - # tensor = tensor[None, :] - - self[key] = tensor - except Exception as e: - if key == "overflowing_tokens": - raise ValueError( - "Unable to create tensor returning overflowing tokens of different lengths. " - "Please see if a fast version of this tokenizer is available to have this feature available." - ) from e - raise ValueError( - "Unable to create tensor, you should probably activate truncation and/or padding with" - " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your" - f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is" - " expected)." - ) from e - - return self - - def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding": - """ - Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only). - - Args: - device (`str` or `torch.device`): The device to put the tensors on. - non_blocking (`bool`): Whether to perform the copy asynchronously. - - Returns: - [`BatchEncoding`]: The same instance after modification. 
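A sketch of the in-place conversion implemented above: values start out as Python lists, and `convert_to_tensors` swaps them for framework tensors (NumPy is used here to avoid a framework dependency; checkpoint illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Without `return_tensors` the values are plain Python lists ...
batch = tokenizer(["short", "a slightly longer sentence"], padding=True)
print(type(batch["input_ids"]))      # <class 'list'>

# ... and the same BatchEncoding can be converted in place afterwards.
batch.convert_to_tensors("np")
print(type(batch["input_ids"]))      # <class 'numpy.ndarray'>
```

Padding matters here: as the error paths above show, ragged (un-padded) batches cannot be packed into a single tensor.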
- """ - requires_backends(self, ["torch"]) - - # This check catches things like APEX blindly calling "to" on all inputs to a module - # Otherwise it passes the casts down and casts the LongTensor containing the token idxs - # into a HalfTensor - if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = { - k: v.to(device=device, non_blocking=non_blocking) if hasattr(v, "to") and callable(v.to) else v - for k, v in self.data.items() - } - else: - logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") - return self - - -class SpecialTokensMixin: - """ - A mixin derived by [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] to handle specific behaviors related to - special tokens. In particular, this class hold the attributes which can be used to directly access these special - tokens in a model-independent manner and allow to set and update the special tokens. - - Args: - bos_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing the beginning of a sentence. - eos_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing the end of a sentence. - unk_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing an out-of-vocabulary token. - sep_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token separating two different sentences in the same input (used by BERT for instance). - pad_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by - attention mechanisms or loss computation. - cls_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing the class of the input (used by BERT for instance). - mask_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing a masked token (used by masked-language modeling pretraining objectives, like - BERT). - additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): - A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be - skipped when decoding if `skip_special_tokens` is set to `True`. - """ - - SPECIAL_TOKENS_ATTRIBUTES = [ - "bos_token", - "eos_token", - "unk_token", - "sep_token", - "pad_token", - "cls_token", - "mask_token", - "additional_special_tokens", - ] - - def __init__(self, verbose=False, **kwargs): - self._pad_token_type_id = 0 - self.verbose = verbose - self._special_tokens_map = dict.fromkeys(self.SPECIAL_TOKENS_ATTRIBUTES) - self._special_tokens_map["additional_special_tokens"] = [] # for BC where it defaults to empty list - - # We directly set the hidden value to allow initialization with special tokens - # which are not yet in the vocabulary. 
Necessary for serialization/de-serialization - # TODO clean this up at some point (probably by switching to fast tokenizers) - - for key, value in kwargs.items(): - if value is None: - continue - if key in self.SPECIAL_TOKENS_ATTRIBUTES: - if key == "additional_special_tokens": - assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" - assert all( - isinstance(t, (str, AddedToken)) for t in value - ), "One of the tokens is not a string or an AddedToken" - setattr(self, key, value) - elif isinstance(value, (str, AddedToken)): - setattr(self, key, value) - else: - raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") - - def sanitize_special_tokens(self) -> int: - """ - The `sanitize_special_tokens` is now deprecated kept for backward compatibility and will be removed in - transformers v5. - """ - logger.warning_once("The `sanitize_special_tokens` will be removed in transformers v5.") - return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) - - def add_special_tokens( - self, - special_tokens_dict: dict[str, Union[str, AddedToken, Sequence[Union[str, AddedToken]]]], - replace_additional_special_tokens=True, - ) -> int: - """ - Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If - special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the - current vocabulary). - - When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the - model so that its embedding matrix matches the tokenizer. - - In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. - - Using `add_special_tokens` will ensure your special tokens can be used in several ways: - - - Special tokens can be skipped when decoding using `skip_special_tokens = True`. - - Special tokens are carefully handled by the tokenizer (they are never split), similar to `AddedTokens`. - - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This - makes it easy to develop model-agnostic training and fine-tuning scripts. - - When possible, special tokens are already registered for provided pretrained models (for instance - [`BertTokenizer`] `cls_token` is already registered to be `'[CLS]'` and XLM's one is also registered to be - `''`). - - Args: - special_tokens_dict (dictionary *str* to *str*, `tokenizers.AddedToken`, or `Sequence[Union[str, AddedToken]]`): - Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`, - `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`]. - - Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer - assign the index of the `unk_token` to them). - replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`): - If `True`, the existing list of additional special tokens will be replaced by the list provided in - `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former - case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged - as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the - `added_tokens_encoder` and `added_tokens_decoder`. 
This means that the previous - `additional_special_tokens` are still added tokens, and will not be split by the model. - - Returns: - `int`: Number of tokens added to the vocabulary. - - Examples: - - ```python - # Let's see how to add a new classification token to GPT-2 - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - model = GPT2Model.from_pretrained("openai-community/gpt2") - - special_tokens_dict = {"cls_token": ""} - - num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) - print("We have added", num_added_toks, "tokens") - # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. - model.resize_token_embeddings(len(tokenizer)) - - assert tokenizer.cls_token == "" - ```""" - if not special_tokens_dict: - return 0 - - added_tokens = [] - for key, value in special_tokens_dict.items(): - assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token" - - if self.verbose: - logger.info(f"Assigning {value} to the {key} key of the tokenizer") - - if key == "additional_special_tokens": - assert isinstance(value, (list, tuple)) and all( - isinstance(t, (str, AddedToken)) for t in value - ), f"Tokens {value} for key {key} should all be str or AddedToken instances" - - to_add = [] - for token in value: - if isinstance(token, str): - # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this - token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True) - if not replace_additional_special_tokens and str(token) in self.additional_special_tokens: - continue - to_add.append(token) - if replace_additional_special_tokens and len(to_add) > 0: - setattr(self, key, list(to_add)) - else: - self._special_tokens_map["additional_special_tokens"].extend(to_add) - added_tokens += to_add - - else: - if not isinstance(value, (str, AddedToken)): - raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance") - if isinstance(value, (str)): - # for legacy purpose we default to stripping. `False` depends on this - value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True) - if isinstance(value, AddedToken): - setattr(self, key, value) - if value not in added_tokens: - added_tokens.append(value) - - # if we are adding tokens that were not part of the vocab, we ought to add them - added_tokens = self.add_tokens(added_tokens, special_tokens=True) - return added_tokens - - def add_tokens( - self, new_tokens: Union[str, AddedToken, Sequence[Union[str, AddedToken]]], special_tokens: bool = False - ) -> int: - """ - Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to - it with indices starting from length of the current vocabulary and will be isolated before the tokenization - algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore - not treated in the same way. - - Note, when adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix - of the model so that its embedding matrix matches the tokenizer. - - In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. - - Args: - new_tokens (`str`, `tokenizers.AddedToken` or a sequence of *str* or `tokenizers.AddedToken`): - Tokens are only added if they are not already in the vocabulary. 
`tokenizers.AddedToken` wraps a string - token to let you personalize its behavior: whether this token should only match against a single word, - whether this token should strip all potential whitespaces on the left side, whether this token should - strip all potential whitespaces on the right side, etc. - special_tokens (`bool`, *optional*, defaults to `False`): - Can be used to specify if the token is a special token. This mostly change the normalization behavior - (special tokens like CLS or [MASK] are usually not lower-cased for instance). - - See details for `tokenizers.AddedToken` in HuggingFace tokenizers library. - - Returns: - `int`: Number of tokens added to the vocabulary. - - Examples: - - ```python - # Let's see how to increase the vocabulary of Bert model and tokenizer - tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased") - model = BertModel.from_pretrained("google-bert/bert-base-uncased") - - num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"]) - print("We have added", num_added_toks, "tokens") - # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. - model.resize_token_embeddings(len(tokenizer)) - ```""" - if not new_tokens: - return 0 - - if not isinstance(new_tokens, (list, tuple)): - new_tokens = [new_tokens] - - return self._add_tokens(new_tokens, special_tokens=special_tokens) - - def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int: - raise NotImplementedError - - @property - def pad_token_type_id(self) -> int: - """ - `int`: Id of the padding token type in the vocabulary. - """ - return self._pad_token_type_id - - def __setattr__(self, key, value): - key_without_id = key - key_is_special_id = key.endswith("_id") or key.endswith("_ids") - if key_is_special_id: - key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] - - if self.__dict__.get("_special_tokens_map", None) is not None and any( - name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] - ): - if key_is_special_id: - if value is not None: - value = ( - self.convert_ids_to_tokens(value) - if key != "additional_special_tokens" - else [self.convert_ids_to_tokens(val) for val in value] - ) - key = key_without_id - - if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None: - raise ValueError(f"Cannot set a non-string value as the {key}") - self._special_tokens_map[key] = value - else: - super().__setattr__(key, value) - - def __getattr__(self, key): - key_without_id = key - key_is_special_id = key.endswith("_id") or key.endswith("_ids") - if key_is_special_id: - key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] - - if self.__dict__.get("_special_tokens_map", None) is not None and any( - name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] - ): - _special_tokens_map = self.__dict__["_special_tokens_map"] - if not key_is_special_id: - if _special_tokens_map[key] is None: - if self.verbose: - logger.error(f"Using {key}, but it is not set yet.") - return None - value = _special_tokens_map[key] - return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value] - else: - attr_as_tokens = getattr(self, key_without_id) - return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None - - if key not in self.__dict__: - raise AttributeError(f"{self.__class__.__name__} has no attribute 
{key}") - else: - return super().__getattr__(key) - - @property - def special_tokens_map(self) -> dict[str, Union[str, list[str]]]: - """ - `dict[str, Union[str, list[str]]]`: A dictionary mapping special token class attributes (`cls_token`, - `unk_token`, etc.) to their values (`''`, `''`, etc.). - - Convert potential tokens of `tokenizers.AddedToken` type to string. - """ - set_attr = {} - for attr in self.SPECIAL_TOKENS_ATTRIBUTES: - attr_value = getattr(self, attr) - if attr_value: - set_attr[attr] = attr_value - return set_attr - - @property - def special_tokens_map_extended(self) -> dict[str, Union[str, AddedToken, list[Union[str, AddedToken]]]]: - """ - `dict[str, Union[str, tokenizers.AddedToken, list[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping - special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`''`, `''`, etc.). - - Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how - special tokens are tokenized. - """ - set_attr = {} - for attr in self.SPECIAL_TOKENS_ATTRIBUTES: - attr_value = self._special_tokens_map[attr] - if attr_value: - set_attr[attr] = attr_value - return set_attr - - @property - def all_special_tokens_extended(self) -> list[Union[str, AddedToken]]: - """ - `list[Union[str, tokenizers.AddedToken]]`: All the special tokens (`''`, `''`, etc.), the order has - nothing to do with the index of each tokens. If you want to know the correct indices, check - `self.added_tokens_encoder`. We can't create an order anymore as the keys are `AddedTokens` and not `Strings`. - - Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how - special tokens are tokenized. - """ - all_tokens = [] - seen = set() - for value in self.special_tokens_map_extended.values(): - if isinstance(value, (list, tuple)): - tokens_to_add = [token for token in value if str(token) not in seen] - else: - tokens_to_add = [value] if str(value) not in seen else [] - seen.update(map(str, tokens_to_add)) - all_tokens.extend(tokens_to_add) - return all_tokens - - @property - def all_special_tokens(self) -> list[str]: - """ - `list[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). - - Convert tokens of `tokenizers.AddedToken` type to string. - """ - all_toks = [str(s) for s in self.all_special_tokens_extended] - return all_toks - - @property - def all_special_ids(self) -> list[int]: - """ - `list[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. - """ - all_toks = self.all_special_tokens - all_ids = self.convert_tokens_to_ids(all_toks) - return all_ids - - def _set_model_specific_special_tokens(self, special_tokens: list[str]): - """ - Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part - of "self.special_tokens" and saved as a special token in tokenizer's config. - This allows us to dynamically add new model-type specific tokens after initializing the tokenizer. - For example: if the model tokenizers is multimodal, we can support special image or audio tokens. 
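The properties above are what make special tokens inspectable in a model-agnostic way; a quick sketch (upstream `transformers` API assumed, checkpoint illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

print(tokenizer.special_tokens_map)   # {'unk_token': '[UNK]', 'sep_token': '[SEP]', ...}
print(tokenizer.all_special_tokens)   # unique special tokens as plain strings
print(tokenizer.all_special_ids)      # their vocabulary ids
# Attribute access is routed through the mixin's __getattr__/__setattr__ above:
print(tokenizer.cls_token, tokenizer.cls_token_id)
```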
- """ - self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys()) - for key, value in special_tokens.items(): - if isinstance(value, (str, AddedToken)): - self._special_tokens_map[key] = value - else: - raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") - - -ENCODE_KWARGS_DOCSTRING = r""" - add_special_tokens (`bool`, *optional*, defaults to `True`): - Whether or not to add special tokens when encoding the sequences. This will use the underlying - `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are - automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens - automatically. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Activates and controls padding. Accepts the following values: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence is provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): - Activates and controls truncation. Accepts the following values: - - - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will - truncate token by token, removing a token from the longest sequence in the pair if a pair of - sequences (or a batch of pairs) is provided. - - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths - greater than the model maximum admissible input size). - max_length (`int`, *optional*): - Controls the maximum length to use by one of the truncation/padding parameters. - - If left unset or set to `None`, this will use the predefined model maximum length if a maximum length - is required by one of the truncation/padding parameters. If the model has no specific maximum input - length (like XLNet) truncation/padding to a maximum length will be deactivated. - stride (`int`, *optional*, defaults to 0): - If set to a number along with `max_length`, the overflowing tokens returned when - `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence - returned to provide some overlap between truncated and overflowing sequences. The value of this - argument defines the number of overlapping tokens. - is_split_into_words (`bool`, *optional*, defaults to `False`): - Whether or not the input is already pre-tokenized (e.g., split into words). 
If set to `True`, the - tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) - which it will tokenize. This is useful for NER or token classification. - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta). - padding_side (`str`, *optional*): - The side on which the model should have padding applied. Should be selected between ['right', 'left']. - Default value is picked from the class attribute of the same name. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. -""" - -ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" - return_token_type_ids (`bool`, *optional*): - Whether to return token type IDs. If left to the default, will return the token type IDs according to - the specific tokenizer's default, defined by the `return_outputs` attribute. - - [What are token type IDs?](../glossary#token-type-ids) - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific tokenizer's default, defined by the `return_outputs` attribute. - - [What are attention masks?](../glossary#attention-mask) - return_overflowing_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch - of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead - of returning overflowing tokens. - return_special_tokens_mask (`bool`, *optional*, defaults to `False`): - Whether or not to return special tokens mask information. - return_offsets_mapping (`bool`, *optional*, defaults to `False`): - Whether or not to return `(char_start, char_end)` for each token. - - This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using - Python's tokenizer, this method will raise `NotImplementedError`. - return_length (`bool`, *optional*, defaults to `False`): - Whether or not to return the lengths of the encoded inputs. - verbose (`bool`, *optional*, defaults to `True`): - Whether or not to print more information and warnings. - **kwargs: passed to the `self.tokenize()` method - - Return: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: - - - **input_ids** -- List of token ids to be fed to a model. - - [What are input IDs?](../glossary#input-ids) - - - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or - if *"token_type_ids"* is in `self.model_input_names`). - - [What are token type IDs?](../glossary#token-type-ids) - - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when - `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). - - [What are attention masks?](../glossary#attention-mask) - - - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and - `return_overflowing_tokens=True`). 
- - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and - `return_overflowing_tokens=True`). - - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying - regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`). - - **length** -- The length of the inputs (when `return_length=True`) -""" - - -INIT_TOKENIZER_DOCSTRING = r""" - Class attributes (overridden by derived classes) - - - **vocab_files_names** (`dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each - vocabulary file required by the model, and as associated values, the filename for saving the associated file - (string). - - **pretrained_vocab_files_map** (`dict[str, dict[str, str]]`) -- A dictionary of dictionaries, with the - high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the - low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the - associated pretrained vocabulary file. - - **model_input_names** (`list[str]`) -- A list of inputs expected in the forward pass of the model. - - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied. - Should be `'right'` or `'left'`. - - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation - applied. Should be `'right'` or `'left'`. - - Args: - model_max_length (`int`, *optional*): - The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is - loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the - value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will - default to VERY_LARGE_INTEGER (`int(1e30)`). - padding_side (`str`, *optional*): - The side on which the model should have padding applied. Should be selected between ['right', 'left']. - Default value is picked from the class attribute of the same name. - truncation_side (`str`, *optional*): - The side on which the model should have truncation applied. Should be selected between ['right', 'left']. - Default value is picked from the class attribute of the same name. - chat_template (`str`, *optional*): - A Jinja template string that will be used to format lists of chat messages. See - https://huggingface.co/docs/transformers/chat_templating for a full description. - model_input_names (`list[string]`, *optional*): - The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or - `"attention_mask"`). Default value is picked from the class attribute of the same name. - bos_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing the beginning of a sentence. Will be associated to `self.bos_token` and - `self.bos_token_id`. - eos_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing the end of a sentence. Will be associated to `self.eos_token` and - `self.eos_token_id`. - unk_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing an out-of-vocabulary token. Will be associated to `self.unk_token` and - `self.unk_token_id`. - sep_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token separating two different sentences in the same input (used by BERT for instance). 
Will be - associated to `self.sep_token` and `self.sep_token_id`. - pad_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by - attention mechanisms or loss computation. Will be associated to `self.pad_token` and `self.pad_token_id`. - cls_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing the class of the input (used by BERT for instance). Will be associated to - `self.cls_token` and `self.cls_token_id`. - mask_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token representing a masked token (used by masked-language modeling pretraining objectives, like - BERT). Will be associated to `self.mask_token` and `self.mask_token_id`. - additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): - A tuple or a list of additional special tokens. Add them here to ensure they are skipped when decoding with - `skip_special_tokens` set to `True`. If they are not part of the vocabulary, they will be added at the end - of the vocabulary. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not the model should clean up the spaces that were added when splitting the input text during the - tokenization process. - split_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the special tokens should be split during the tokenization process. Passing this will affect the - internal state of the tokenizer. The default behavior is to not split special tokens. This means that if - `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") = ['<s>']`. Otherwise, if - `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give `['<', 's', '>']`. -""" - - -@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) -class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): - """ - Base class for [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`]. - - Handles shared (mostly boilerplate) methods for those two classes. - """ - - vocab_files_names: dict[str, str] = {} - pretrained_vocab_files_map: dict[str, dict[str, str]] = {} - _auto_class: Optional[str] = None - - # first name has to correspond to main model input name - # to make sure `tokenizer.pad(...)` works correctly - model_input_names: list[str] = ["input_ids", "token_type_ids", "attention_mask"] - padding_side: str = "right" - truncation_side: str = "right" - slow_tokenizer_class = None - - def __init__(self, **kwargs): - # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) - self.init_inputs = () - for key in kwargs: - if hasattr(self, key) and callable(getattr(self, key)): - raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") - - self.init_kwargs = copy.deepcopy(kwargs) - self.name_or_path = kwargs.pop("name_or_path", "") - self._processor_class = kwargs.pop("processor_class", None) - - # For backward compatibility we fallback to set model_max_length from max_len if provided - model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) - self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER - - # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it - # is changed.
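A small sketch of the `split_special_tokens` switch documented above; the checkpoint name and the exact sub-tokens shown are illustrative only:

```python
from transformers import AutoTokenizer

checkpoint = "google-bert/bert-base-uncased"  # example checkpoint

default_tok = AutoTokenizer.from_pretrained(checkpoint)
split_tok = AutoTokenizer.from_pretrained(checkpoint, split_special_tokens=True)

# By default special tokens are kept intact ...
print(default_tok.tokenize("[CLS] hello"))  # ['[CLS]', 'hello']
# ... while split_special_tokens=True tokenizes them like ordinary text.
print(split_tok.tokenize("[CLS] hello"))    # e.g. ['[', 'cl', '##s', ']', 'hello']
```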
- self.padding_side = kwargs.pop("padding_side", self.padding_side) - if self.padding_side not in ["right", "left"]: - raise ValueError( - f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" - ) - - self.truncation_side = kwargs.pop("truncation_side", self.truncation_side) - if self.truncation_side not in ["right", "left"]: - raise ValueError( - f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" - ) - - self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) - - # By default, cleaning tokenization spaces for both fast and slow tokenizers - self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False) - - # By default, do not split special tokens for both fast and slow tokenizers - self.split_special_tokens = kwargs.pop("split_special_tokens", False) - - self.deprecation_warnings = ( - {} - ) # Use to store when we have already noticed a deprecation warning (avoid overlogging). - self._in_target_context_manager = False - - # Stores a Jinja template that formats chat histories into tokenizable strings - self.chat_template = kwargs.pop("chat_template", None) - if isinstance(self.chat_template, (list, tuple)): - # Chat templates are stored as lists of dicts with fixed key names, - # we reconstruct that into a single dict while loading them. - self.chat_template = {template["name"]: template["template"] for template in self.chat_template} - - super().__init__(**kwargs) - - self.extra_special_tokens = kwargs.pop("extra_special_tokens", {}) - self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens) - - @property - def max_len_single_sentence(self) -> int: - """ - `int`: The maximum length of a sentence that can be fed to the model. - """ - return self.model_max_length - self.num_special_tokens_to_add(pair=False) - - @property - def max_len_sentences_pair(self) -> int: - """ - `int`: The maximum combined length of a pair of sentences that can be fed to the model. - """ - return self.model_max_length - self.num_special_tokens_to_add(pair=True) - - @max_len_single_sentence.setter - def max_len_single_sentence(self, value) -> int: - # For backward compatibility, allow to try to setup 'max_len_single_sentence'. - if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose: - if not self.deprecation_warnings.get("max_len_single_sentence", False): - logger.warning( - "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up." - ) - self.deprecation_warnings["max_len_single_sentence"] = True - else: - raise ValueError("Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up.") - - @max_len_sentences_pair.setter - def max_len_sentences_pair(self, value) -> int: - # For backward compatibility, allow to try to setup 'max_len_sentences_pair'. - if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose: - if not self.deprecation_warnings.get("max_len_sentences_pair", False): - logger.warning( - "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up." - ) - self.deprecation_warnings["max_len_sentences_pair"] = True - else: - raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. 
This value is automatically set up.") - - def _set_processor_class(self, processor_class: str): - """Sets processor class as an attribute.""" - self._processor_class = processor_class - - @property - def added_tokens_decoder(self) -> dict[int, AddedToken]: - raise NotImplementedError() - - def __repr__(self) -> str: - added_tokens_decoder_rep = "\n\t".join([f"{k}: {v.__repr__()}," for k, v in self.added_tokens_decoder.items()]) - return ( - f"{self.__class__.__name__}(name_or_path='{self.name_or_path}'," - f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast}," - f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}'," - f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}," - " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)" - ) - - def __len__(self) -> int: - raise NotImplementedError() - - def get_vocab(self) -> dict[str, int]: - """ - Returns the vocabulary as a dictionary of token to index. - - `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the - vocab. - - Returns: - `dict[str, int]`: The vocabulary. - """ - raise NotImplementedError() - - def apply_chat_template( - self, - conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], - tools: Optional[list[Union[dict, Callable]]] = None, - documents: Optional[list[dict[str, str]]] = None, - chat_template: Optional[str] = None, - add_generation_prompt: bool = False, - continue_final_message: bool = False, - tokenize: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: bool = False, - max_length: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_dict: bool = False, - return_assistant_tokens_mask: bool = False, - tokenizer_kwargs: Optional[dict[str, Any]] = None, - **kwargs, - ) -> Union[str, list[int], list[str], list[list[int]], BatchEncoding]: - """ - Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token - ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to - determine the format and control tokens to use when converting. - - Args: - conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts - with "role" and "content" keys, representing the chat history so far. - tools (`list[Union[Dict, Callable]]`, *optional*): - A list of tools (callable functions) that will be accessible to the model. If the template does not - support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, - giving the name, description and argument types for the tool. See our - [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) - for more information. - documents (`list[dict[str, str]]`, *optional*): - A list of dicts representing documents that will be accessible to the model if it is performing RAG - (retrieval-augmented generation). If the template does not support RAG, this argument will have no - effect. We recommend that each document should be a dict containing "title" and "text" keys. Please - see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) - for examples of passing documents with chat templates. 
- chat_template (`str`, *optional*): - A Jinja template to use for this conversion. It is usually not necessary to pass anything to this - argument, as the model's template will be used by default. - add_generation_prompt (bool, *optional*): - If this is set, a prompt with the token(s) that indicate - the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. - Note that this argument will be passed to the chat template, and so it must be supported in the - template for this argument to have any effect. - continue_final_message (bool, *optional*): - If this is set, the chat will be formatted so that the final - message in the chat is open-ended, without any EOS tokens. The model will continue this message - rather than starting a new one. This allows you to "prefill" part of - the model's response for it. Cannot be used at the same time as `add_generation_prompt`. - tokenize (`bool`, defaults to `True`): - Whether to tokenize the output. If `False`, the output will be a string. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - truncation (`bool`, defaults to `False`): - Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. - max_length (`int`, *optional*): - Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If - not specified, the tokenizer's `max_length` attribute will be used as a default. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable - values are: - - `'tf'`: Return TensorFlow `tf.Tensor` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - return_dict (`bool`, defaults to `False`): - Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. - tokenizer_kwargs (`dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. - return_assistant_tokens_mask (`bool`, defaults to `False`): - Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, - the mask will contain 1. For user and system tokens, the mask will contain 0. - This functionality is only available for chat templates that support it via the `{% generation %}` keyword. - **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. - - Returns: - `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This - output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is - set, will return a dict of tokenizer outputs instead. 
- """ - - if return_dict and not tokenize: - raise ValueError( - "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict " - "of tokenizer outputs to return." - ) - - if return_assistant_tokens_mask and not return_dict: - raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`") - - if tokenizer_kwargs is None: - tokenizer_kwargs = {} - - chat_template = self.get_chat_template(chat_template, tools) - - if isinstance(conversation, (list, tuple)) and ( - isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") - ): - conversations = conversation - is_batched = True - else: - conversations = [conversation] - is_batched = False - - if continue_final_message: - if add_generation_prompt: - raise ValueError( - "continue_final_message and add_generation_prompt are not compatible. \ - Use continue_final_message when you want the model to continue the final message, \ - and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." - ) - if return_assistant_tokens_mask: - raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") - - template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present - rendered_chat, generation_indices = render_jinja_template( - conversations=conversations, - tools=tools, - documents=documents, - chat_template=chat_template, - return_assistant_tokens_mask=return_assistant_tokens_mask, - continue_final_message=continue_final_message, - add_generation_prompt=add_generation_prompt, - **template_kwargs, - ) - - if not is_batched: - rendered_chat = rendered_chat[0] - - if tokenize: - out = self( - rendered_chat, - padding=padding, - truncation=truncation, - max_length=max_length, - add_special_tokens=False, - return_tensors=return_tensors, - **tokenizer_kwargs, - ) - if return_dict: - if return_assistant_tokens_mask: - assistant_masks = [] - if is_batched or return_tensors: - input_ids = out["input_ids"] - else: - input_ids = [out["input_ids"]] - for i in range(len(input_ids)): - current_mask = [0] * len(input_ids[i]) - for assistant_start_char, assistant_end_char in generation_indices[i]: - start_token = out.char_to_token(i, assistant_start_char) - end_token = out.char_to_token(i, assistant_end_char - 1) - if start_token is None: - # start_token is out of bounds maybe due to truncation. - break - for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): - current_mask[token_id] = 1 - assistant_masks.append(current_mask) - - if not is_batched and not return_tensors: - assistant_masks = assistant_masks[0] - - out["assistant_masks"] = assistant_masks - - if return_tensors: - out.convert_to_tensors(tensor_type=return_tensors) - - return out - else: - return out["input_ids"] - else: - return rendered_chat - - def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str: - """ - Retrieve the chat template string used for tokenizing chat messages. This template is used - internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat - template for better generation tracking. - - Args: - chat_template (`str`, *optional*): - A Jinja template or the name of a template to use for this conversion. - It is usually not necessary to pass anything to this argument, - as the model's template will be used by default. 
- tools (`list[Dict]`, *optional*): - A list of tools (callable functions) that will be accessible to the model. If the template does not - support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, - giving the name, description and argument types for the tool. See our - [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) - for more information. - - Returns: - `str`: The chat template string. - """ - # First, handle the cases when the model has a dict of multiple templates - if isinstance(self.chat_template, dict): - template_dict = self.chat_template - if chat_template is not None and chat_template in template_dict: - # The user can pass the name of a template to the chat template argument instead of an entire template - chat_template = template_dict[chat_template] - elif chat_template is None: - if tools is not None and "tool_use" in template_dict: - chat_template = template_dict["tool_use"] - elif "default" in template_dict: - chat_template = template_dict["default"] - else: - raise ValueError( - "This model has multiple chat templates with no default specified! Please either pass a chat " - "template or the name of the template you wish to use to the `chat_template` argument. Available " - f"template names are {sorted(template_dict.keys())}." - ) - - elif chat_template is None: - # These are the cases when the model has a single template - # priority: `chat_template` argument > `tokenizer.chat_template` - if self.chat_template is not None: - chat_template = self.chat_template - else: - raise ValueError( - "Cannot use chat template functions because tokenizer.chat_template is not set and no template " - "argument was passed! For information about writing templates and setting the " - "tokenizer.chat_template attribute, please see the documentation at " - "https://huggingface.co/docs/transformers/main/en/chat_templating" - ) - - return chat_template - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - *init_inputs, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - local_files_only: bool = False, - token: Optional[Union[str, bool]] = None, - revision: str = "main", - trust_remote_code=False, - **kwargs, - ): - r""" - Instantiate a [`~tokenization_utils_base.PreTrainedTokenizerBase`] (or a derived class) from a predefined - tokenizer. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. - - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved - using the [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`] method, e.g., - `./my_model_directory/`. - - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary - file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g., - `./my_model_directory/vocab.txt`. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download the vocabulary files and override the cached versions if they - exist. 
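A hypothetical sketch of the template-selection logic above when `chat_template` is a dict of named templates (the template bodies are placeholders):

```python
# Assume `tok` is any loaded tokenizer; the template strings below are dummies.
tok.chat_template = {
    "default": "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}",
    "tool_use": "{% for m in messages %}[TOOLS] {{ m['content'] }}\n{% endfor %}",
}

print(tok.get_chat_template()[:25])                           # resolves to the "default" entry
print(tok.get_chat_template(chat_template="tool_use")[:25])   # lookup by template name
# Passing tools without naming a template also prefers "tool_use" when it exists.
print(tok.get_chat_template(tools=[{"type": "function"}])[:25])
```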
- resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `hf auth login` (stored in `~/.huggingface`). - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only rely on local files and not to attempt to download any files. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*): - In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for - facebook/rag-token-base), specify it here. - inputs (additional positional arguments, *optional*): - Will be passed along to the Tokenizer `__init__` method. - trust_remote_code (`bool`, *optional*, defaults to `False`): - Whether or not to allow for custom models defined on the Hub in their own modeling files. This option - should only be set to `True` for repositories you trust and in which you have read the code, as it will - execute code present on the Hub on your local machine. - kwargs (additional keyword arguments, *optional*): - Will be passed to the Tokenizer `__init__` method. Can be used to set special tokens like `bos_token`, - `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`, - `additional_special_tokens`. See parameters in the `__init__` for more details. - - <Tip> - - Passing `token=True` is required when you want to use a private model. - - </Tip> - - Examples: - - ```python - # We can't instantiate directly the base class *PreTrainedTokenizerBase* so let's show our examples on a derived class: BertTokenizer - # Download vocabulary from huggingface.co and cache. - tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") - - # Download vocabulary from huggingface.co (user-uploaded) and cache. - tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-german-cased") - - # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*) - tokenizer = BertTokenizer.from_pretrained("./test/saved_model/") - - # If the tokenizer uses a single vocabulary file, you can point directly to this file - tokenizer = BertTokenizer.from_pretrained("./test/saved_model/my_vocab.txt") - - # You can link tokens to special vocabulary when instantiating - tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", unk_token="<unk>") - # You should be sure '<unk>' is in the vocabulary when doing that.
- # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead) - assert tokenizer.unk_token == "<unk>" - ```""" - resume_download = kwargs.pop("resume_download", None) - proxies = kwargs.pop("proxies", None) - use_auth_token = kwargs.pop("use_auth_token", None) - subfolder = kwargs.pop("subfolder", None) - from_pipeline = kwargs.pop("_from_pipeline", None) - from_auto_class = kwargs.pop("_from_auto", False) - commit_hash = kwargs.pop("_commit_hash", None) - gguf_file = kwargs.get("gguf_file", None) - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__} - if from_pipeline is not None: - user_agent["using_pipeline"] = from_pipeline - - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - vocab_files = {} - init_configuration = {} - - is_local = os.path.isdir(pretrained_model_name_or_path) - single_file_id = None - if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): - if len(cls.vocab_files_names) > 1 and not gguf_file: - raise ValueError( - f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " - "supported for this tokenizer. Use a model identifier or the path to a directory instead." - ) - warnings.warn( - f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " - "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", - FutureWarning, - ) - file_id = list(cls.vocab_files_names.keys())[0] - - vocab_files[file_id] = pretrained_model_name_or_path - single_file_id = file_id - else: - if gguf_file: - vocab_files["vocab_file"] = gguf_file - else: - # At this point pretrained_model_name_or_path is either a directory or a model identifier name - additional_files_names = { - "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy - "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy - "tokenizer_config_file": TOKENIZER_CONFIG_FILE, - # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders - "tokenizer_file": FULL_TOKENIZER_FILE, - "chat_template_file": CHAT_TEMPLATE_FILE, - } - - vocab_files = {**cls.vocab_files_names, **additional_files_names} - if "tokenizer_file" in vocab_files: - # Try to get the tokenizer config to see if there are versioned tokenizer files.
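As a side note, keyword arguments passed to `from_pretrained` are merged into the saved init kwargs, which makes load-time overrides straightforward; a hypothetical example:

```python
from transformers import AutoTokenizer

# Override saved defaults at load time, e.g. left padding for decoder-only generation
# or a tighter model_max_length; the checkpoint name is only an example.
tok = AutoTokenizer.from_pretrained(
    "google-bert/bert-base-uncased",
    padding_side="left",
    model_max_length=256,
)
print(tok.padding_side)      # "left"
print(tok.model_max_length)  # 256
```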
- fast_tokenizer_file = FULL_TOKENIZER_FILE - - try: - resolved_config_file = cached_file( - pretrained_model_name_or_path, - TOKENIZER_CONFIG_FILE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - subfolder=subfolder, - user_agent=user_agent, - _raise_exceptions_for_missing_entries=False, - _commit_hash=commit_hash, - ) - except OSError: - # Re-raise any error raised by cached_file in order to get a helpful error message - raise - except Exception: - # For any other exception, we throw a generic error. - raise OSError( - f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing all relevant files for a {cls.__name__} tokenizer." - ) - - commit_hash = extract_commit_hash(resolved_config_file, commit_hash) - if resolved_config_file is not None: - with open(resolved_config_file, encoding="utf-8") as reader: - tokenizer_config = json.load(reader) - if "fast_tokenizer_files" in tokenizer_config: - fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) - vocab_files["tokenizer_file"] = fast_tokenizer_file - - # This block looks for any extra chat template files - if is_local: - template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) - if template_dir.is_dir(): - for template_file in template_dir.glob("*.jinja"): - template_name = template_file.name.removesuffix(".jinja") - vocab_files[ - f"chat_template_{template_name}" - ] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}" - else: - for template in list_repo_templates( - pretrained_model_name_or_path, - local_files_only=local_files_only, - revision=revision, - cache_dir=cache_dir, - ): - vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja" - - # Get files from url, cache, or disk depending on the case - resolved_vocab_files = {} - for file_id, file_path in vocab_files.items(): - if file_path is None: - resolved_vocab_files[file_id] = None - elif single_file_id == file_id: - if os.path.isfile(file_path): - resolved_vocab_files[file_id] = file_path - elif is_remote_url(file_path): - resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies) - else: - try: - resolved_vocab_files[file_id] = cached_file( - pretrained_model_name_or_path, - file_path, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - _commit_hash=commit_hash, - ) - except OSError: - # Re-raise any error raised by cached_file in order to get a helpful error message - raise - except Exception: - # For any other exception, we throw a generic error. - raise OSError( - f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing all relevant files for a {cls.__name__} tokenizer." 
- ) - commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash) - - for file_id, file_path in vocab_files.items(): - if file_id not in resolved_vocab_files: - continue - - if is_local: - logger.info(f"loading file {file_path}") - else: - logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") - - return cls._from_pretrained( - resolved_vocab_files, - pretrained_model_name_or_path, - init_configuration, - *init_inputs, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - _commit_hash=commit_hash, - _is_local=is_local, - trust_remote_code=trust_remote_code, - **kwargs, - ) - - @classmethod - def _from_pretrained( - cls, - resolved_vocab_files, - pretrained_model_name_or_path, - init_configuration, - *init_inputs, - token=None, - cache_dir=None, - local_files_only=False, - _commit_hash=None, - _is_local=False, - trust_remote_code=False, - **kwargs, - ): - # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json - # file or if `from_slow` is set to True. - from_slow = kwargs.get("from_slow", False) - gguf_file = kwargs.get("gguf_file", None) - has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None - - # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be - # loaded directly from the GGUF file. - if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file: - slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( - copy.deepcopy(resolved_vocab_files), - pretrained_model_name_or_path, - copy.deepcopy(init_configuration), - *init_inputs, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - _commit_hash=_commit_hash, - **(copy.deepcopy(kwargs)), - ) - else: - slow_tokenizer = None - - # Prepare tokenizer initialization kwargs - # Did we saved some inputs and kwargs to reload ? - tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None) - if tokenizer_config_file is not None: - with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: - init_kwargs = json.load(tokenizer_config_handle) - # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers. 
- config_tokenizer_class = init_kwargs.get("tokenizer_class") - init_kwargs.pop("tokenizer_class", None) - if not has_tokenizer_file: - init_kwargs.pop("tokenizer_file", None) - saved_init_inputs = init_kwargs.pop("init_inputs", ()) - if not init_inputs: - init_inputs = saved_init_inputs - else: - config_tokenizer_class = None - init_kwargs = init_configuration - - # If independent chat template file(s) exist, they take priority over template entries in the tokenizer config - chat_templates = {} - chat_template_file = resolved_vocab_files.pop("chat_template_file", None) - extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")] - if chat_template_file is not None: - with open(chat_template_file, encoding="utf-8") as chat_template_handle: - chat_templates["default"] = chat_template_handle.read() - for extra_chat_template in extra_chat_templates: - template_file = resolved_vocab_files.pop(extra_chat_template, None) - if template_file is None: - continue # I think this should never happen, but just in case - template_name = extra_chat_template.removeprefix("chat_template_") - with open(template_file) as chat_template_handle: - chat_templates[template_name] = chat_template_handle.read() - if len(chat_templates) == 1 and "default" in chat_templates: - init_kwargs["chat_template"] = chat_templates["default"] - elif chat_templates: - init_kwargs["chat_template"] = chat_templates - - if not _is_local: - if "auto_map" in init_kwargs: - # For backward compatibility with odl format. - if isinstance(init_kwargs["auto_map"], (tuple, list)): - init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]} - - if config_tokenizer_class is None: - # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo. - # If not, it raises a warning, but otherwise continues. Since we mostly load tokenizers with - # AutoTokenizer these days, it seems like a lot of work (and a source of bugs) for little gain. - # Maybe we can just remove this entirely? - from .models.auto.configuration_auto import AutoConfig # tests_ignore - - # Second attempt. If we have not yet found tokenizer_class, let's try to use the config. - try: - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - trust_remote_code=trust_remote_code, - _commit_hash=_commit_hash, - ) - config_tokenizer_class = config.tokenizer_class - except (OSError, ValueError, KeyError): - # skip if an error occurred. - config = None - if config_tokenizer_class is None: - # Third attempt. If we have not yet found the original type of the tokenizer, - # we are loading we see if we can infer it from the type of the configuration file - from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore - - if hasattr(config, "model_type"): - model_type = config.model_type - else: - # Fallback: use pattern matching on the string. 
- model_type = None - for pattern in TOKENIZER_MAPPING_NAMES.keys(): - if pattern in str(pretrained_model_name_or_path): - model_type = pattern - break - - if model_type is not None: - config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES.get( - model_type, (None, None) - ) - if config_tokenizer_class is None: - config_tokenizer_class = config_tokenizer_class_fast - - if config_tokenizer_class is not None: - if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""): - logger.warning( - "The tokenizer class you load from this checkpoint is not the same type as the class this" - " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you" - f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called" - f" from is '{cls.__name__}'." - ) - - # Update with newly provided kwargs - init_kwargs.update(kwargs) - - # Merge resolved_vocab_files arguments in init_kwargs. - added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) - special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None) - for args_name, file_path in resolved_vocab_files.items(): - if args_name not in init_kwargs: - init_kwargs[args_name] = file_path - tokenizer_file = resolved_vocab_files.pop("tokenizer_file", None) - - if slow_tokenizer is not None: - init_kwargs["__slow_tokenizer"] = slow_tokenizer - init_kwargs["name_or_path"] = pretrained_model_name_or_path - - # Handle tokenizer serialization of added and special tokens - added_tokens_decoder: dict[int, AddedToken] = {} - added_tokens_map: dict[str, AddedToken] = {} - # if we have info on the slow added tokens - if "added_tokens_decoder" in init_kwargs: - for idx, token in init_kwargs["added_tokens_decoder"].items(): - if isinstance(token, dict): - token = AddedToken(**token) - if isinstance(token, AddedToken): - added_tokens_decoder[int(idx)] = token - added_tokens_map[str(token)] = token - else: - raise TypeError( - f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance" - ) - else: - # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified - if special_tokens_map_file is not None: - with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: - special_tokens_map = json.load(special_tokens_map_handle) - for key, value in special_tokens_map.items(): - if key in kwargs and kwargs[key]: - # This value has already been redefined by the kwargs - # We keep this new value and ignore the one stored in the special_tokens_map_file - continue - if isinstance(value, dict): - value["special"] = True - value = AddedToken(**value) - elif key == "additional_special_tokens" and isinstance(value, list): - additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or [] - for token in value: - if isinstance(token, dict): - token["special"] = True - token = AddedToken(**token) - if token not in additional_special_tokens: - additional_special_tokens.append(token) - value = additional_special_tokens - init_kwargs[key] = value - - # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`. - # this is for legacy purpose. We don't add the tokens after init for efficiency. 
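For reference, a small hypothetical sketch of the `AddedToken` objects this legacy block reconstructs, using the public `add_tokens` API:

```python
from transformers import AddedToken, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")  # example checkpoint

# Register an extra special token; the flags mirror the fields stored in added_tokens_decoder.
num_added = tok.add_tokens(
    [AddedToken("<extra_pad>", lstrip=False, rstrip=False, normalized=False)],
    special_tokens=True,
)
new_id = tok.convert_tokens_to_ids("<extra_pad>")
print(num_added, new_id)                 # 1, an id appended after the original vocab
print(tok.added_tokens_decoder[new_id])  # the reconstructed AddedToken entry
```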
- if added_tokens_file is not None: - special_tokens = [] - for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys(): - if init_kwargs[key] is not None: - if key == "additional_special_tokens": - special_tokens += [str(token) for token in init_kwargs[key]] - else: - special_tokens.append(str(init_kwargs[key])) - - with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: - added_tok_encoder = json.load(added_tokens_handle) - for str_token, index in added_tok_encoder.items(): - # if index not in added_tokens_decoder and str_token not in added_tokens_map: - special = str_token in special_tokens - added_tokens_decoder[index] = AddedToken( - str_token, rstrip=False, lstrip=False, normalized=not special, special=special - ) - added_tokens_map[str(token)] = added_tokens_decoder[index] - - # allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer - # if `tokenizer_config.json` is `None` - if tokenizer_file is not None: - # This is for slow so can be done before - with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle: - tokenizer_file_handle = json.load(tokenizer_file_handle) - added_tokens = tokenizer_file_handle.pop("added_tokens") - for serialized_tokens in added_tokens: - idx = serialized_tokens.pop("id") - added_tokens_decoder[idx] = AddedToken(**serialized_tokens) - added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx] - # end legacy - - # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken - # convert {'__type': 'AddedToken', 'content': '', 'lstrip': False, 'normalized': True, ...} to AddedTokens - init_kwargs["added_tokens_decoder"] = added_tokens_decoder - init_kwargs = cls.convert_added_tokens(init_kwargs, save=False) - for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys(): - if added_tokens_map != {} and init_kwargs[key] is not None: - if key != "additional_special_tokens": - init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key]) - - # Instantiate the tokenizer. - try: - tokenizer = cls(*init_inputs, **init_kwargs) - except import_protobuf_decode_error(): - logger.info( - "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." - "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).", - ) - return False - except RuntimeError as e: - if "sentencepiece_processor.cc" in str(e): - logger.info( - "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." - "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).", - ) - return False - except OSError: - raise OSError( - "Unable to load vocabulary from file. " - "Please check that the provided vocabulary is accessible and not corrupted." - ) - - if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size: - logger.info( - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are" - " fine-tuned or trained." - ) - return tokenizer - - @staticmethod - def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): - # This method should be deleted in Transformers v5 - # Its only purpose is to potentially throw a warning - # that incorrectly defined max lengths of T5's tokenizer are used - # which we will correct in Transformers v5. 
- return max_model_length - - @classmethod - def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_field=True): - if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken": - obj.pop("__type") - return AddedToken(**obj) - if isinstance(obj, AddedToken) and save: - obj = obj.__getstate__() - if add_type_field: - obj["__type"] = "AddedToken" - else: - # Don't save "special" for previous tokenizers - obj.pop("special") - return obj - elif isinstance(obj, (list, tuple)): - return [cls.convert_added_tokens(o, save=save, add_type_field=add_type_field) for o in obj] - elif isinstance(obj, dict): - return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()} - return obj - - def save_chat_templates( - self, - save_directory: Union[str, os.PathLike], - tokenizer_config: dict, - filename_prefix: Optional[str], - save_jinja_files: bool, - ): - """ - Writes chat templates out to the save directory if we're using the new format, and removes them from - the tokenizer config if present. If we're using the legacy format, it doesn't write any files, and instead - writes the templates to the tokenizer config in the correct format. - """ - chat_template_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE - ) - chat_template_dir = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR - ) - - saved_raw_chat_template_files = [] - if save_jinja_files and isinstance(self.chat_template, str): - # New format for single templates is to save them as chat_template.jinja - with open(chat_template_file, "w", encoding="utf-8") as f: - f.write(self.chat_template) - logger.info(f"chat template saved in {chat_template_file}") - saved_raw_chat_template_files.append(chat_template_file) - if "chat_template" in tokenizer_config: - tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too - elif save_jinja_files and isinstance(self.chat_template, dict): - # New format for multiple templates is to save the default as chat_template.jinja - # and the other templates in the chat_templates/ directory - for template_name, template in self.chat_template.items(): - if template_name == "default": - with open(chat_template_file, "w", encoding="utf-8") as f: - f.write(self.chat_template["default"]) - logger.info(f"chat template saved in {chat_template_file}") - saved_raw_chat_template_files.append(chat_template_file) - else: - Path(chat_template_dir).mkdir(exist_ok=True) - template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") - with open(template_filepath, "w", encoding="utf-8") as f: - f.write(template) - logger.info(f"chat template saved in {template_filepath}") - saved_raw_chat_template_files.append(template_filepath) - if "chat_template" in tokenizer_config: - tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too - elif isinstance(self.chat_template, dict): - # Legacy format for multiple templates: - # chat template dicts are saved to the config as lists of dicts with fixed key names. 
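A hypothetical round trip for the template-saving behaviour described above; the exact file names depend on the `CHAT_TEMPLATE_FILE` / `CHAT_TEMPLATE_DIR` constants defined in this module:

```python
import os
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")  # any tokenizer works here
tok.chat_template = {
    "default": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
    "rag": "{{ documents | length }} documents",
}

out_dir = "/tmp/tok_with_templates"
tok.save_pretrained(out_dir)

# With the raw-template format, the "default" entry should land in its own .jinja file and the
# named "rag" template under the extra chat-template directory, not inside tokenizer_config.json.
print(sorted(os.listdir(out_dir)))
```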
- tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] - elif self.chat_template is not None: - # Legacy format for single templates: Just make them a key in tokenizer_config.json - tokenizer_config["chat_template"] = self.chat_template - return tokenizer_config, saved_raw_chat_template_files - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - legacy_format: Optional[bool] = None, - filename_prefix: Optional[str] = None, - push_to_hub: bool = False, - **kwargs, - ) -> tuple[str]: - """ - Save the full tokenizer state. - - - This method make sure the full tokenizer can then be re-loaded using the - [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] class method.. - - Warning,None This won't save modifications you may have applied to the tokenizer after the instantiation (for - instance, modifying `tokenizer.do_lower_case` after creation). - - Args: - save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved. - legacy_format (`bool`, *optional*): - Only applicable for a fast tokenizer. If unset (default), will save the tokenizer in the unified JSON - format as well as in legacy format if it exists, i.e. with tokenizer specific vocabulary and a separate - added_tokens files. - - If `False`, will only save the tokenizer in the unified JSON format. This format is incompatible with - "slow" tokenizers (not powered by the *tokenizers* library), so the tokenizer will not be able to be - loaded in the corresponding "slow" tokenizer. - - If `True`, will save the tokenizer in legacy format. If the "slow" tokenizer doesn't exits, a value - error is raised. - filename_prefix (`str`, *optional*): - A prefix to add to the names of the files saved by the tokenizer. - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - kwargs (`dict[str, Any]`, *optional*): - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - - Returns: - A tuple of `str`: The files saved. - """ - use_auth_token = kwargs.pop("use_auth_token", None) - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if kwargs.get("token", None) is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." 
- ) - kwargs["token"] = use_auth_token - - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = self._create_repo(repo_id, **kwargs) - files_timestamps = self._get_files_timestamps(save_directory) - - special_tokens_map_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE - ) - tokenizer_config_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE - ) - - tokenizer_config = copy.deepcopy(self.init_kwargs) - - # Let's save the init kwargs - target_keys = set(self.init_kwargs.keys()) - # Let's save the special tokens map (only the strings) - target_keys.update(["model_max_length", "clean_up_tokenization_spaces"]) - - for k in target_keys: - if hasattr(self, k): - tokenizer_config[k] = getattr(self, k) - - # Let's make sure we properly save the special tokens - tokenizer_config.update(self.special_tokens_map) - if "extra_special_tokens" not in tokenizer_config: - tokenizer_config["extra_special_tokens"] = self.extra_special_tokens - tokenizer_config.update(self.extra_special_tokens) - - save_jinja_files = kwargs.get("save_jinja_files", True) - tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates( - save_directory, tokenizer_config, filename_prefix, save_jinja_files - ) - - if len(self.init_inputs) > 0: - tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) - for file_id in self.vocab_files_names.keys(): - tokenizer_config.pop(file_id, None) - - # no typefields, this way old fast and slow can load it - tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True) - - # Process added tokens separately: allows previous versions to ignore it! - added_tokens = {} - for key, value in self.added_tokens_decoder.items(): - added_tokens[key] = value.__getstate__() - tokenizer_config["added_tokens_decoder"] = added_tokens - - # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained - tokenizer_class = self.__class__.__name__ - # Remove the Fast at the end if we can save the slow tokenizer - if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False): - tokenizer_class = tokenizer_class[:-4] - tokenizer_config["tokenizer_class"] = tokenizer_class - if getattr(self, "_auto_map", None) is not None: - tokenizer_config["auto_map"] = self._auto_map - if getattr(self, "_processor_class", None) is not None: - tokenizer_config["processor_class"] = self._processor_class - - # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be - # loaded from the Hub. 
- if self._auto_class is not None: - custom_object_save(self, save_directory, config=tokenizer_config) - - # remove private information - if "name_or_path" in tokenizer_config: - tokenizer_config.pop("name_or_path") - tokenizer_config.pop("special_tokens_map_file", None) - tokenizer_config.pop("tokenizer_file", None) - if "device_map" in tokenizer_config: - tokenizer_config.pop("device_map") - - with open(tokenizer_config_file, "w", encoding="utf-8") as f: - out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n" - f.write(out_str) - logger.info(f"tokenizer config file saved in {tokenizer_config_file}") - - # Sanitize AddedTokens in special_tokens_map - - # kept for forward compatibility, will be removed in transoformers 5. Typefields are not saved for FC, special should not be save either - write_dict = self.convert_added_tokens(self.special_tokens_map_extended, save=True, add_type_field=False) - with open(special_tokens_map_file, "w", encoding="utf-8") as f: - out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n" - f.write(out_str) - logger.info(f"Special tokens file saved in {special_tokens_map_file}") - - file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files) - - save_files = self._save_pretrained( - save_directory=save_directory, - file_names=file_names, - legacy_format=legacy_format, - filename_prefix=filename_prefix, - ) - - if push_to_hub: - self._upload_modified_files( - save_directory, - repo_id, - files_timestamps, - commit_message=commit_message, - token=kwargs.get("token"), - ) - - return save_files - - def _save_pretrained( - self, - save_directory: Union[str, os.PathLike], - file_names: tuple[str], - legacy_format: Optional[bool] = None, - filename_prefix: Optional[str] = None, - ) -> tuple[str]: - """ - Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens. - - Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the - specific [`~tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`] - """ - if legacy_format is False: - raise ValueError( - "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format." - ) - - save_directory = str(save_directory) - - added_tokens_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE - ) - # the new get_added_vocab() also returns special tokens and tokens that have an index < vocab_size - added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size} - if added_vocab: - with open(added_tokens_file, "w", encoding="utf-8") as f: - out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n" - f.write(out_str) - logger.info(f"added tokens file saved in {added_tokens_file}") - - vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) - - return file_names + vocab_files + (added_tokens_file,) - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: - """ - Save only the vocabulary of the tokenizer (vocabulary + added tokens). - - This method won't save the configuration and special token mappings of the tokenizer. Use - [`~PreTrainedTokenizerFast._save_pretrained`] to save the whole state of the tokenizer. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. 
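To make the file layout concrete, a hypothetical save/reload round trip (paths and checkpoint are examples only):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

# Returns the paths of everything written: tokenizer_config.json, special_tokens_map.json,
# the vocabulary file(s) and, for fast tokenizers, tokenizer.json.
saved_files = tok.save_pretrained("/tmp/bert-tok")
print(saved_files)

# The directory is self-contained and reloads without touching the Hub.
reloaded = AutoTokenizer.from_pretrained("/tmp/bert-tok")
assert reloaded.get_vocab() == tok.get_vocab()
```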
- filename_prefix (`str`, *optional*): - An optional prefix to add to the named of the saved files. - - Returns: - `Tuple(str)`: Paths to the files saved. - """ - raise NotImplementedError - - def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]: - """ - Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. - - Args: - text (`str`): - The sequence to be encoded. - pair (`str`, *optional*): - A second sequence to be encoded with the first. - add_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to add the special tokens associated with the corresponding model. - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific encode method. See details in - [`~PreTrainedTokenizerBase.__call__`] - - Returns: - `list[str]`: The list of tokens. - """ - raise NotImplementedError - - @add_end_docstrings( - ENCODE_KWARGS_DOCSTRING, - """ - **kwargs: Passed along to the `.tokenize()` method. - """, - """ - Returns: - `list[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text. - """, - ) - def encode( - self, - text: Union[TextInput, PreTokenizedInput, EncodedInput], - text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy, None] = None, - max_length: Optional[int] = None, - stride: int = 0, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> list[int]: - """ - Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. - - Same as doing `self.convert_tokens_to_ids(self.tokenize(text))`. - - Args: - text (`str`, `list[str]` or `list[int]`): - The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the - `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` - method). - text_pair (`str`, `list[str]` or `list[int]`, *optional*): - Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using - the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` - method). - """ - encoded_inputs = self.encode_plus( - text, - text_pair=text_pair, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - padding_side=padding_side, - return_tensors=return_tensors, - **kwargs, - ) - - return encoded_inputs["input_ids"] - - def num_special_tokens_to_add(self, pair: bool = False) -> int: - raise NotImplementedError - - def _get_padding_truncation_strategies( - self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs - ): - """ - Find the correct padding/truncation strategy - """ - - # Backward compatibility for previous behavior, maybe we should deprecate it: - # If you only set max_length, it activates truncation for max_length - if max_length is not None and padding is False and truncation is None: - if verbose: - if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): - logger.warning( - "Truncation was not explicitly activated but `max_length` is provided a specific value, please" - " use `truncation=True` to explicitly truncate examples to max length. 
Defaulting to" - " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the" - " tokenizer you can select this strategy more precisely by providing a specific strategy to" - " `truncation`." - ) - self.deprecation_warnings["Truncation-not-explicitly-activated"] = True - truncation = "longest_first" - - # Get padding strategy - if padding is not False: - if padding is True: - if verbose: - if max_length is not None and ( - truncation is None or truncation is False or truncation == "do_not_truncate" - ): - warnings.warn( - "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. " - "To pad to max length, use `padding='max_length'`." - ) - padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch - elif not isinstance(padding, PaddingStrategy): - padding_strategy = PaddingStrategy(padding) - elif isinstance(padding, PaddingStrategy): - padding_strategy = padding - else: - padding_strategy = PaddingStrategy.DO_NOT_PAD - - # Get truncation strategy - if truncation is not False and truncation is not None: - if truncation is True: - truncation_strategy = ( - TruncationStrategy.LONGEST_FIRST - ) # Default to truncate the longest sequences in pairs of inputs - elif not isinstance(truncation, TruncationStrategy): - truncation_strategy = TruncationStrategy(truncation) - elif isinstance(truncation, TruncationStrategy): - truncation_strategy = truncation - else: - truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE - - # Set max length if needed - if max_length is None: - if padding_strategy == PaddingStrategy.MAX_LENGTH: - if self.model_max_length > LARGE_INTEGER: - if verbose: - if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): - logger.warning( - "Asking to pad to max_length but no maximum length is provided and the model has no" - " predefined maximum length. Default to no padding." - ) - self.deprecation_warnings["Asking-to-pad-to-max_length"] = True - padding_strategy = PaddingStrategy.DO_NOT_PAD - else: - max_length = self.model_max_length - - if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: - if self.model_max_length > LARGE_INTEGER: - if verbose: - if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): - logger.warning( - "Asking to truncate to max_length but no maximum length is provided and the model has" - " no predefined maximum length. Default to no truncation." - ) - self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True - truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE - else: - max_length = self.model_max_length - - # Test if we have a padding token - if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0): - raise ValueError( - "Asking to pad but the tokenizer does not have a padding token. " - "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " - "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." 
- ) - - # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided - if ( - truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE - and padding_strategy != PaddingStrategy.DO_NOT_PAD - and pad_to_multiple_of is not None - and max_length is not None - and (max_length % pad_to_multiple_of != 0) - ): - raise ValueError( - "Truncation and padding are both activated but " - f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." - ) - - return padding_strategy, truncation_strategy, max_length, kwargs - - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def __call__( - self, - text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None, - text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, - text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None, - text_pair_target: Optional[ - Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] - ] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy, None] = None, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs, - ) -> BatchEncoding: - """ - Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of - sequences. - - Args: - text (`str`, `list[str]`, `list[list[str]]`, *optional*): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - text_target (`str`, `list[str]`, `list[list[str]]`, *optional*): - The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a - list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), - you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*): - The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a - list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), - you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
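[Editor's sketch, not part of the patch: a minimal usage example of the `__call__` signature documented above. It assumes the `t5-small` checkpoint is reachable and that `text_target` is returned under `labels`, as the docstring describes.]

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
batch = tok(
    ["translate English to German: Hello world"],   # `text`
    text_target=["Hallo Welt"],                      # encoded in target mode, returned as `labels`
    padding=True,
    truncation=True,
    max_length=32,
    return_tensors="np",
)
print(sorted(batch.keys()))   # ['attention_mask', 'input_ids', 'labels']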
- """ - # To avoid duplicating - all_kwargs = { - "add_special_tokens": add_special_tokens, - "padding": padding, - "truncation": truncation, - "max_length": max_length, - "stride": stride, - "is_split_into_words": is_split_into_words, - "pad_to_multiple_of": pad_to_multiple_of, - "padding_side": padding_side, - "return_tensors": return_tensors, - "return_token_type_ids": return_token_type_ids, - "return_attention_mask": return_attention_mask, - "return_overflowing_tokens": return_overflowing_tokens, - "return_special_tokens_mask": return_special_tokens_mask, - "return_offsets_mapping": return_offsets_mapping, - "return_length": return_length, - "split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens), - "verbose": verbose, - } - - if return_tensors in ("tf", "jax"): - logger.warning_once( - "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " - "recommend migrating to PyTorch classes or pinning your version of Transformers." - ) - all_kwargs.update(kwargs) - if text is None and text_target is None: - raise ValueError("You need to specify either `text` or `text_target`.") - if text is not None: - # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the - # input mode in this case. - if not self._in_target_context_manager: - self._switch_to_input_mode() - encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs) - if text_target is not None: - self._switch_to_target_mode() - target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs) - # Leave back tokenizer in input mode - self._switch_to_input_mode() - - if text_target is None: - return encodings - elif text is None: - return target_encodings - else: - encodings["labels"] = target_encodings["input_ids"] - return encodings - - def _call_one( - self, - text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], - text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy, None] = None, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - split_special_tokens: bool = False, - **kwargs, - ) -> BatchEncoding: - # Input type checking for clearer error - def _is_valid_text_input(t): - if isinstance(t, str): - # Strings are fine - return True - elif isinstance(t, (list, tuple)): - # List are fine as long as they are... - if len(t) == 0: - # ... empty - return True - elif isinstance(t[0], str): - # ... list of strings - return True - elif isinstance(t[0], (list, tuple)): - # ... 
list with an empty list or with a list of strings - return len(t[0]) == 0 or isinstance(t[0][0], str) - else: - return False - else: - return False - - if not _is_valid_text_input(text): - raise ValueError( - "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) " - "or `list[list[str]]` (batch of pretokenized examples)." - ) - - if text_pair is not None and not _is_valid_text_input(text_pair): - raise ValueError( - "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) " - "or `list[list[str]]` (batch of pretokenized examples)." - ) - - if is_split_into_words: - is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) - else: - is_batched = isinstance(text, (list, tuple)) - - if is_batched: - if isinstance(text_pair, str): - raise TypeError( - "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as" - " `text`." - ) - if text_pair is not None and len(text) != len(text_pair): - raise ValueError( - f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: " - f" {len(text_pair)}." - ) - batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text - return self.batch_encode_plus( - batch_text_or_text_pairs=batch_text_or_text_pairs, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - split_special_tokens=split_special_tokens, - **kwargs, - ) - else: - return self.encode_plus( - text=text, - text_pair=text_pair, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - split_special_tokens=split_special_tokens, - **kwargs, - ) - - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def encode_plus( - self, - text: Union[TextInput, PreTokenizedInput, EncodedInput], - text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy, None] = None, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - 
return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs, - ) -> BatchEncoding: - """ - Tokenize and prepare for the model a sequence or a pair of sequences. - - - - This method is deprecated, `__call__` should be used instead. - - - - Args: - text (`str`, `list[str]` or (for non-fast tokenizers) `list[int]`): - The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the - `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` - method). - text_pair (`str`, `list[str]` or `list[int]`, *optional*): - Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using - the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` - method). - """ - - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) - - return self._encode_plus( - text=text, - text_pair=text_pair, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - split_special_tokens=kwargs.pop("split_special_tokens", self.split_special_tokens), - **kwargs, - ) - - def _encode_plus( - self, - text: Union[TextInput, PreTokenizedInput, EncodedInput], - text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - split_special_tokens: bool = False, - **kwargs, - ) -> BatchEncoding: - raise NotImplementedError - - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def batch_encode_plus( - self, - batch_text_or_text_pairs: Union[ - list[TextInput], - list[TextInputPair], - list[PreTokenizedInput], - list[PreTokenizedInputPair], - list[EncodedInput], - list[EncodedInputPair], - ], - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy, None] = None, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - 
return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - split_special_tokens: bool = False, - **kwargs, - ) -> BatchEncoding: - """ - Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. - - - - This method is deprecated, `__call__` should be used instead. - - - - Args: - batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, \ - `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`): - Batch of sequences or pair of sequences to be encoded. This can be a list of - string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see - details in `encode_plus`). - """ - - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) - - return self._batch_encode_plus( - batch_text_or_text_pairs=batch_text_or_text_pairs, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - split_special_tokens=split_special_tokens, - **kwargs, - ) - - def _batch_encode_plus( - self, - batch_text_or_text_pairs: Union[ - list[TextInput], - list[TextInputPair], - list[PreTokenizedInput], - list[PreTokenizedInputPair], - list[EncodedInput], - list[EncodedInputPair], - ], - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - split_special_tokens: bool = False, - **kwargs, - ) -> BatchEncoding: - raise NotImplementedError - - def pad( - self, - encoded_inputs: Union[ - BatchEncoding, - list[BatchEncoding], - dict[str, EncodedInput], - dict[str, list[EncodedInput]], - list[dict[str, EncodedInput]], - ], - padding: Union[bool, str, PaddingStrategy] = True, - max_length: Optional[int] = None, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_attention_mask: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - verbose: bool = True, - 
) -> BatchEncoding: - """ - Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length - in the batch. - - Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, - `self.pad_token_id` and `self.pad_token_type_id`). - - Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the - text followed by a call to the `pad` method to get a padded encoding. - - - - If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the - result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of - PyTorch tensors, you will lose the specific device of your tensors however. - - - - Args: - encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`): - Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of - tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str, - list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader - collate function. - - Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see - the note above for the return type. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta). - padding_side (`str`, *optional*): - The side on which the model should have padding applied. Should be selected between ['right', 'left']. - Default value is picked from the class attribute of the same name. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific tokenizer's default, defined by the `return_outputs` attribute. - - [What are attention masks?](../glossary#attention-mask) - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - verbose (`bool`, *optional*, defaults to `True`): - Whether or not to print more information and warnings. - """ - if self.__class__.__name__.endswith("Fast"): - if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False): - logger.warning_advice( - f"You're using a {self.__class__.__name__} tokenizer. 
Please note that with a fast tokenizer," - " using the `__call__` method is faster than using a method to encode the text followed by a call" - " to the `pad` method to get a padded encoding." - ) - self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True - - # If we have a list of dicts, let's convert it in a dict of lists - # We do this to allow using this method as a collate_fn function in PyTorch Dataloader - if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): - encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} - - # The model's main input name, usually `input_ids`, has been passed for padding - if self.model_input_names[0] not in encoded_inputs: - raise ValueError( - "You should supply an encoding or a list of encodings to this method " - f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" - ) - - required_input = encoded_inputs[self.model_input_names[0]] - - if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0): - if return_attention_mask: - encoded_inputs["attention_mask"] = [] - return encoded_inputs - - # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects - # and rebuild them afterwards if no return_tensors is specified - # Note that we lose the specific device the tensor may be on for PyTorch - - first_element = required_input[0] - if isinstance(first_element, (list, tuple)): - # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. - for item in required_input: - if len(item) != 0: - first_element = item[0] - break - # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. - if not isinstance(first_element, (int, list, tuple)): - if is_tf_tensor(first_element): - return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_tensor(first_element): - return_tensors = "pt" if return_tensors is None else return_tensors - elif isinstance(first_element, np.ndarray): - return_tensors = "np" if return_tensors is None else return_tensors - else: - raise ValueError( - f"type of {first_element} unknown: {type(first_element)}. " - "Should be one of a python, numpy, pytorch or tensorflow object." - ) - - for key, value in encoded_inputs.items(): - encoded_inputs[key] = to_py_obj(value) - - # Convert padding_strategy in PaddingStrategy - padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( - padding=padding, max_length=max_length, verbose=verbose - ) - - required_input = encoded_inputs[self.model_input_names[0]] - if required_input and not isinstance(required_input[0], (list, tuple)): - encoded_inputs = self._pad( - encoded_inputs, - max_length=max_length, - padding_strategy=padding_strategy, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_attention_mask=return_attention_mask, - ) - return BatchEncoding(encoded_inputs, tensor_type=return_tensors) - - batch_size = len(required_input) - assert all( - len(v) == batch_size for v in encoded_inputs.values() - ), "Some items in the output dictionary have a different batch size than others." 
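[Editor's sketch, not part of the patch: the collate-function use case of `pad` mentioned above. It assumes `bert-base-uncased` is available and that ids 101/7592/2088/999/102 map to [CLS]/hello/world/!/[SEP].]

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
features = [
    {"input_ids": [101, 7592, 102]},             # "hello"
    {"input_ids": [101, 7592, 2088, 999, 102]},  # "hello world!"
]
batch = tok.pad(features, padding=True, return_tensors="np")
print(batch["input_ids"].shape)     # (2, 5): shorter row is right-padded with tok.pad_token_id
print(batch["attention_mask"][0])   # [1 1 1 0 0]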
- - if padding_strategy == PaddingStrategy.LONGEST: - max_length = max(len(inputs) for inputs in required_input) - padding_strategy = PaddingStrategy.MAX_LENGTH - - batch_outputs = {} - for i in range(batch_size): - inputs = {k: v[i] for k, v in encoded_inputs.items()} - outputs = self._pad( - inputs, - max_length=max_length, - padding_strategy=padding_strategy, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - return_attention_mask=return_attention_mask, - ) - - for key, value in outputs.items(): - if key not in batch_outputs: - batch_outputs[key] = [] - batch_outputs[key].append(value) - - return BatchEncoding(batch_outputs, tensor_type=return_tensors) - - def create_token_type_ids_from_sequences( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None - ) -> list[int]: - """ - Create the token type IDs corresponding to the sequences passed. [What are token type - IDs?](../glossary#token-type-ids) - - Should be overridden in a subclass if the model has a special way of building those. - - Args: - token_ids_0 (`list[int]`): The first tokenized sequence. - token_ids_1 (`list[int]`, *optional*): The second tokenized sequence. - - Returns: - `list[int]`: The token type ids. - """ - cls_len = int(getattr(self, "cls_token_id", None) is not None) - sep_len = int(getattr(self, "sep_token_id", None) is not None) - - if token_ids_1 is None: - return [0] * (cls_len + len(token_ids_0) + sep_len) - - return [0] * (cls_len + len(token_ids_0) + sep_len) + [1] * (len(token_ids_1) + sep_len) - - def build_inputs_with_special_tokens( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None - ) -> list[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and - adding special tokens. - - This implementation does not add special tokens and this method should be overridden in a subclass. - - Args: - token_ids_0 (`list[int]`): The first tokenized sequence. - token_ids_1 (`list[int]`, *optional*): The second tokenized sequence. - - Returns: - `list[int]`: The model input with special tokens. - """ - if token_ids_1 is None: - return token_ids_0 - return token_ids_0 + token_ids_1 - - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def prepare_for_model( - self, - ids: list[int], - pair_ids: Optional[list[int]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy, None] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - prepend_batch_axis: bool = False, - **kwargs, - ) -> BatchEncoding: - """ - Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It - adds special tokens, truncates sequences if overflowing while taking into account the special tokens and - manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids* - different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return - overflowing tokens. 
Such a combination of arguments will raise an error. - - Args: - ids (`list[int]`): - Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and - `convert_tokens_to_ids` methods. - pair_ids (`list[int]`, *optional*): - Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` - and `convert_tokens_to_ids` methods. - """ - - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) - - pair = bool(pair_ids is not None) - len_ids = len(ids) - len_pair_ids = len(pair_ids) if pair else 0 - - if return_token_type_ids and not add_special_tokens: - raise ValueError( - "Asking to return token_type_ids while setting add_special_tokens to False " - "results in an undefined behavior. Please set add_special_tokens to True or " - "set return_token_type_ids to None." - ) - - if ( - return_overflowing_tokens - and truncation_strategy == TruncationStrategy.LONGEST_FIRST - and pair_ids is not None - ): - raise ValueError( - "Not possible to return overflowing tokens for pair of sequences with the " - "`longest_first`. Please select another truncation strategy than `longest_first`, " - "for instance `only_second` or `only_first`." - ) - - # Load from model defaults - if return_token_type_ids is None: - return_token_type_ids = "token_type_ids" in self.model_input_names - if return_attention_mask is None: - return_attention_mask = "attention_mask" in self.model_input_names - - encoded_inputs = {} - - # Compute the total size of the returned encodings - total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) - - # Truncation: Handle max sequence length - overflowing_tokens = [] - if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: - ids, pair_ids, overflowing_tokens = self.truncate_sequences( - ids, - pair_ids=pair_ids, - num_tokens_to_remove=total_len - max_length, - truncation_strategy=truncation_strategy, - stride=stride, - ) - - if return_overflowing_tokens: - encoded_inputs["overflowing_tokens"] = overflowing_tokens - encoded_inputs["num_truncated_tokens"] = total_len - max_length - - # Add special tokens - if add_special_tokens: - sequence = self.build_inputs_with_special_tokens(ids, pair_ids) - token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) - else: - sequence = ids + pair_ids if pair else ids - token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) - - # Build output dictionary - encoded_inputs["input_ids"] = sequence - if return_token_type_ids: - encoded_inputs["token_type_ids"] = token_type_ids - if return_special_tokens_mask: - if add_special_tokens: - encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) - else: - encoded_inputs["special_tokens_mask"] = [0] * len(sequence) - - # Check lengths - self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) - - # Padding - if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: - encoded_inputs = self.pad( - encoded_inputs, - max_length=max_length, - padding=padding_strategy.value, - pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, - 
return_attention_mask=return_attention_mask, - ) - - if return_length: - encoded_inputs["length"] = len(encoded_inputs["input_ids"]) - - batch_outputs = BatchEncoding(encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis) - - return batch_outputs - - def truncate_sequences( - self, - ids: list[int], - pair_ids: Optional[list[int]] = None, - num_tokens_to_remove: int = 0, - truncation_strategy: Union[str, TruncationStrategy] = "longest_first", - stride: int = 0, - ) -> tuple[list[int], list[int], list[int]]: - """ - Truncates a sequence pair in-place following the strategy. - - Args: - ids (`list[int]`): - Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and - `convert_tokens_to_ids` methods. - pair_ids (`list[int]`, *optional*): - Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` - and `convert_tokens_to_ids` methods. - num_tokens_to_remove (`int`, *optional*, defaults to 0): - Number of tokens to remove using the truncation strategy. - truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`): - The strategy to follow for truncation. Can be: - - - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will truncate - token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a - batch of pairs) is provided. - - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater - than the model maximum admissible input size). - stride (`int`, *optional*, defaults to 0): - If set to a positive number, the overflowing tokens returned will contain some tokens from the main - sequence returned. The value of this argument defines the number of additional tokens. - - Returns: - `tuple[list[int], list[int], list[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of - overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair - of sequences (or a batch of pairs) is provided. 
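[Editor's sketch, not part of the patch: exercising `truncate_sequences` on plain id lists. It assumes a tokenizer whose `truncation_side` is the default "right"; the id lists are arbitrary stand-ins.]

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = list(range(10))              # stand-in for the first sequence
pair_ids = list(range(100, 108))   # stand-in for the second sequence
ids, pair_ids, overflow = tok.truncate_sequences(
    ids, pair_ids=pair_ids, num_tokens_to_remove=4,
    truncation_strategy="only_second", stride=2,
)
# Only the pair is cut; overflow keeps stride + num_tokens_to_remove = 6 trailing ids of the original pair.
print(len(ids), len(pair_ids), len(overflow))   # 10 4 6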
- """ - if num_tokens_to_remove <= 0: - return ids, pair_ids, [] - - if not isinstance(truncation_strategy, TruncationStrategy): - truncation_strategy = TruncationStrategy(truncation_strategy) - - overflowing_tokens = [] - if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( - truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None - ): - if len(ids) > num_tokens_to_remove: - window_len = min(len(ids), stride + num_tokens_to_remove) - if self.truncation_side == "left": - overflowing_tokens = ids[:window_len] - ids = ids[num_tokens_to_remove:] - elif self.truncation_side == "right": - overflowing_tokens = ids[-window_len:] - ids = ids[:-num_tokens_to_remove] - else: - raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.") - - else: - error_msg = ( - f"We need to remove {num_tokens_to_remove} to truncate the input " - f"but the first sequence has a length {len(ids)}. " - ) - if truncation_strategy == TruncationStrategy.ONLY_FIRST: - error_msg = ( - error_msg + "Please select another truncation strategy than " - f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." - ) - logger.error(error_msg) - elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: - logger.warning( - "Be aware, overflowing tokens are not returned for the setting you have chosen," - f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " - "truncation strategy. So the returned list will always be empty even if some " - "tokens have been removed." - ) - len_pair_ids = len(pair_ids) if pair_ids is not None else 0 - len_ids = len(ids) - first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove) - second_remove = num_tokens_to_remove - first_remove - if len_ids > len_pair_ids: - ids_to_move = first_remove + second_remove // 2 - pair_ids_to_move = second_remove - second_remove // 2 - else: - ids_to_move = second_remove // 2 - pair_ids_to_move = first_remove + second_remove - (second_remove // 2) - - if self.truncation_side == "right": - ids = ids[:-ids_to_move] if ids_to_move > 0 else ids - pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids - elif self.truncation_side == "left": - ids = ids[ids_to_move:] - pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None - else: - raise ValueError(f"invalid truncation strategy: {self.truncation_side}") - - elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: - if len(pair_ids) > num_tokens_to_remove: - window_len = min(len(pair_ids), stride + num_tokens_to_remove) - if self.truncation_side == "right": - overflowing_tokens = pair_ids[-window_len:] - pair_ids = pair_ids[:-num_tokens_to_remove] - elif self.truncation_side == "left": - overflowing_tokens = pair_ids[:window_len] - pair_ids = pair_ids[num_tokens_to_remove:] - else: - raise ValueError(f"invalid truncation strategy: {self.truncation_side}") - else: - logger.error( - f"We need to remove {num_tokens_to_remove} to truncate the input " - f"but the second sequence has a length {len(pair_ids)}. " - f"Please select another truncation strategy than {truncation_strategy}, " - "for instance 'longest_first' or 'only_first'." 
- ) - - return (ids, pair_ids, overflowing_tokens) - - def _pad( - self, - encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], - max_length: Optional[int] = None, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - pad_to_multiple_of: Optional[int] = None, - padding_side: Optional[str] = None, - return_attention_mask: Optional[bool] = None, - ) -> dict: - """ - Pad encoded inputs (on left/right and up to predefined length or max length in the batch) - - Args: - encoded_inputs: - Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`). - max_length: maximum length of the returned list and optionally padding length (see below). - Will truncate by taking into account the special tokens. - padding_strategy: PaddingStrategy to use for padding. - - - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in `padding_side` argument: - - - 'left': pads on the left of the sequences - - 'right': pads on the right of the sequences - pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. - This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability - `>= 7.5` (Volta). - padding_side: - The side on which the model should have padding applied. Should be selected between ['right', 'left']. - Default value is picked from the class attribute of the same name. - return_attention_mask: - (optional) Set to False to avoid returning attention mask (default: set to model specifics) - """ - # Load from model defaults - if return_attention_mask is None: - return_attention_mask = "attention_mask" in self.model_input_names - - required_input = encoded_inputs[self.model_input_names[0]] - - if padding_strategy == PaddingStrategy.LONGEST: - max_length = len(required_input) - - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length - - # Initialize attention mask if not present. 
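[Editor's sketch, not part of the patch: a standalone restatement of the `pad_to_multiple_of` rounding used in `_pad` above, for a quick sanity check.]

def round_up_to_multiple(max_length: int, pad_to_multiple_of: int) -> int:
    # Mirrors ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of when not already aligned
    if max_length % pad_to_multiple_of != 0:
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    return max_length

assert round_up_to_multiple(37, 8) == 40   # 37 tokens padded up to the next multiple of 8
assert round_up_to_multiple(40, 8) == 40   # already aligned, unchanged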
- if return_attention_mask and "attention_mask" not in encoded_inputs: - encoded_inputs["attention_mask"] = [1] * len(required_input) - - if needs_to_be_padded: - difference = max_length - len(required_input) - padding_side = padding_side if padding_side is not None else self.padding_side - - if padding_side == "right": - if return_attention_mask: - encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference - if "token_type_ids" in encoded_inputs: - encoded_inputs["token_type_ids"] = ( - encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference - ) - if "special_tokens_mask" in encoded_inputs: - encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference - encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference - elif padding_side == "left": - if return_attention_mask: - encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] - if "token_type_ids" in encoded_inputs: - encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ - "token_type_ids" - ] - if "special_tokens_mask" in encoded_inputs: - encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] - encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input - else: - raise ValueError(f"Invalid padding strategy: {padding_side}") - - return encoded_inputs - - def convert_tokens_to_string(self, tokens: list[str]) -> str: - """ - Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we - often want to remove sub-word tokenization artifacts at the same time. - - Args: - tokens (`list[str]`): The token to join in a string. - - Returns: - `str`: The joined tokens. - """ - raise NotImplementedError - - def batch_decode( - self, - sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs, - ) -> list[str]: - """ - Convert a list of lists of token ids into a list of strings by calling decode. - - Args: - sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): - List of tokenized input ids. Can be obtained using the `__call__` method. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - clean_up_tokenization_spaces (`bool`, *optional*): - Whether or not to clean up the tokenization spaces. If `None`, will default to - `self.clean_up_tokenization_spaces`. - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific decode method. - - Returns: - `list[str]`: The list of decoded sentences. - """ - return [ - self.decode( - seq, - skip_special_tokens=skip_special_tokens, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - for seq in sequences - ] - - def decode( - self, - token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs, - ) -> str: - """ - Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special - tokens and clean up tokenization spaces. - - Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. 
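[Editor's sketch, not part of the patch: an encode/decode round trip illustrating `decode` and `skip_special_tokens`. It assumes `bert-base-uncased` and the ids shown; `clean_up_tokenization_spaces` is passed explicitly so the outputs are deterministic.]

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = tok.encode("hello world!")   # [101, 7592, 2088, 999, 102]
print(tok.decode(ids, clean_up_tokenization_spaces=True))
# "[CLS] hello world! [SEP]"
print(tok.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True))
# "hello world!"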
- - Args: - token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): - List of tokenized input ids. Can be obtained using the `__call__` method. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - clean_up_tokenization_spaces (`bool`, *optional*): - Whether or not to clean up the tokenization spaces. If `None`, will default to - `self.clean_up_tokenization_spaces`. - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific decode method. - - Returns: - `str`: The decoded sentence. - """ - # Convert inputs to python lists - token_ids = to_py_obj(token_ids) - - return self._decode( - token_ids=token_ids, - skip_special_tokens=skip_special_tokens, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - - def _decode( - self, - token_ids: Union[int, list[int]], - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs, - ) -> str: - raise NotImplementedError - - def get_special_tokens_mask( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False - ) -> list[int]: - """ - Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. - - Args: - token_ids_0 (`list[int]`): - List of ids of the first sequence. - token_ids_1 (`list[int]`, *optional*): - List of ids of the second sequence. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - assert already_has_special_tokens and token_ids_1 is None, ( - "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " - "Please use a slow (full python) tokenizer to activate this argument. " - "Or set `return_special_tokens_mask=True` when calling the encoding method " - "to get the special tokens mask in any tokenizer. " - ) - - all_special_ids = self.all_special_ids # cache the property - - special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] - - return special_tokens_mask - - @staticmethod - def clean_up_tokenization(out_string: str) -> str: - """ - Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms. - - Args: - out_string (`str`): The text to clean up. - - Returns: - `str`: The cleaned-up string. - """ - out_string = ( - out_string.replace(" .", ".") - .replace(" ?", "?") - .replace(" !", "!") - .replace(" ,", ",") - .replace(" ' ", "'") - .replace(" n't", "n't") - .replace(" 'm", "'m") - .replace(" 's", "'s") - .replace(" 've", "'ve") - .replace(" 're", "'re") - ) - return out_string - - def _eventual_warn_about_too_long_sequence(self, ids: list[int], max_length: Optional[int], verbose: bool): - """ - Depending on the input and internal state we might trigger a warning about a sequence that is too long for its - corresponding model - - Args: - ids (`list[str]`): The ids produced by the tokenization - max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set) - verbose (`bool`): Whether or not to print more information and warnings. 
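[Editor's sketch, not part of the patch: a pure-python restatement of the replacement chain in `clean_up_tokenization` above, usable as a sanity check without loading any tokenizer.]

def clean_up(out_string: str) -> str:
    for src, dst in [(" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","), (" ' ", "'"),
                     (" n't", "n't"), (" 'm", "'m"), (" 's", "'s"), (" 've", "'ve"), (" 're", "'re")]:
        out_string = out_string.replace(src, dst)
    return out_string

assert clean_up("do n't stop , it 's fine !") == "don't stop, it's fine!"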
- - """ - if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0: - if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): - logger.warning( - "Token indices sequence length is longer than the specified maximum sequence length " - f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model " - "will result in indexing errors" - ) - self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True - - def _switch_to_input_mode(self): - """ - Private method to put the tokenizer in input mode (when it has different modes for input/outputs) - """ - pass - - def _switch_to_target_mode(self): - """ - Private method to put the tokenizer in target mode (when it has different modes for input/outputs) - """ - pass - - @contextmanager - def as_target_tokenizer(self): - """ - Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to - sequence-to-sequence models that need a slightly different processing for the labels. - """ - warnings.warn( - "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your " - "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as " - "your input texts if you use the same keyword arguments, or in a separate call." - ) - self._switch_to_target_mode() - self._in_target_context_manager = True - yield - self._in_target_context_manager = False - self._switch_to_input_mode() - - @classmethod - def register_for_auto_class(cls, auto_class="AutoTokenizer"): - """ - Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the - library are already mapped with `AutoTokenizer`. - - - - Args: - auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`): - The auto class to register this new tokenizer with. - """ - if not isinstance(auto_class, str): - auto_class = auto_class.__name__ - - import transformers.models.auto as auto_module - - if not hasattr(auto_module, auto_class): - raise ValueError(f"{auto_class} is not a valid auto class.") - - cls._auto_class = auto_class - - def prepare_seq2seq_batch( - self, - src_texts: list[str], - tgt_texts: Optional[list[str]] = None, - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - padding: str = "longest", - return_tensors: Optional[str] = None, - truncation: bool = True, - **kwargs, - ) -> BatchEncoding: - """ - Prepare model inputs for translation. For best performance, translate one sentence at a time. - - Arguments: - src_texts (`list[str]`): - List of documents to summarize or source language texts. - tgt_texts (`list`, *optional*): - List of summaries or target language texts. - max_length (`int`, *optional*): - Controls the maximum length for encoder inputs (documents to summarize or source language texts) If - left unset or set to `None`, this will use the predefined model maximum length if a maximum length is - required by one of the truncation/padding parameters. If the model has no specific maximum input length - (like XLNet) truncation/padding to a maximum length will be deactivated. - max_target_length (`int`, *optional*): - Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set - to `None`, this will use the max_length value. 
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Activates and controls padding. Accepts the following values: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `True`): - Activates and controls truncation. Accepts the following values: - - - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will - truncate token by token, removing a token from the longest sequence in the pair if a pair of - sequences (or a batch of pairs) is provided. - - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths - greater than the model maximum admissible input size). - **kwargs: - Additional keyword arguments passed along to `self.__call__`. - - Return: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: - - - **input_ids** -- List of token ids to be fed to the encoder. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **labels** -- List of token ids for tgt_texts. - - The full set of keys `[input_ids, attention_mask, labels]`, will only be returned if tgt_texts is passed. - Otherwise, input_ids, attention_mask will be the only keys. - """ - # docstyle-ignore - formatted_warning = """ -`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular -`__call__` method to prepare your inputs and targets. - -Here is a short example: - -model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...) - -If you either need to use different keyword arguments for the source and target texts, you should do two calls like -this: - -model_inputs = tokenizer(src_texts, ...) -labels = tokenizer(text_target=tgt_texts, ...) -model_inputs["labels"] = labels["input_ids"] - -See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice. -For a more complete example, see the implementation of `prepare_seq2seq_batch`. 
-""" - warnings.warn(formatted_warning, FutureWarning) - # mBART-specific kwargs that should be ignored by other models. - kwargs.pop("src_lang", None) - kwargs.pop("tgt_lang", None) - if max_length is None: - max_length = self.model_max_length - model_inputs = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - with self.as_target_tokenizer(): - labels = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=truncation, - **kwargs, - ) - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - -def get_fast_tokenizer_file(tokenization_files: list[str]) -> str: - """ - Get the tokenization file to use for this version of transformers. - - Args: - tokenization_files (`list[str]`): The list of available configuration files. - - Returns: - `str`: The tokenization file to use. - """ - tokenizer_files_map = {} - for file_name in tokenization_files: - search = _re_tokenizer_file.search(file_name) - if search is not None: - v = search.groups()[0] - tokenizer_files_map[v] = file_name - available_versions = sorted(tokenizer_files_map.keys()) - - # Defaults to FULL_TOKENIZER_FILE and then try to look at some newer versions. - tokenizer_file = FULL_TOKENIZER_FILE - transformers_version = version.parse(__version__) - for v in available_versions: - if version.parse(v) <= transformers_version: - tokenizer_file = tokenizer_files_map[v] - else: - # No point going further since the versions are sorted. - break - - return tokenizer_file - - -# To update the docstring, we need to copy the method, otherwise we change the original docstring. 
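[Editor's sketch, not part of the patch: a standalone version of the version gate in `get_fast_tokenizer_file` above. The file map and the "tokenizer.json" default are hypothetical placeholders; the real keys come from `_re_tokenizer_file` matches.]

from packaging import version

def pick_tokenizer_file(files_map: dict, current: str, default: str = "tokenizer.json") -> str:
    chosen = default
    for v in sorted(files_map):                        # same string sort as the original
        if version.parse(v) <= version.parse(current):
            chosen = files_map[v]                      # newest file at or below the running version wins
        else:
            break
    return chosen

print(pick_tokenizer_file({"3.0.0": "tokenizer.3.0.0.json", "5.0.0": "tokenizer.5.0.0.json"}, "4.54.0"))
# -> tokenizer.3.0.0.json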
-PreTrainedTokenizerBase.push_to_hub = copy_func(PreTrainedTokenizerBase.push_to_hub) -if PreTrainedTokenizerBase.push_to_hub.__doc__ is not None: - PreTrainedTokenizerBase.push_to_hub.__doc__ = PreTrainedTokenizerBase.push_to_hub.__doc__.format( - object="tokenizer", object_class="AutoTokenizer", object_files="tokenizer files" - ) From 91609b956ef834af7d6dffb6fb2f913fbc948b31 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:31:29 +0800 Subject: [PATCH 43/94] resize stacked images one by one --- mindone/transformers/image_processing_utils_fast.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mindone/transformers/image_processing_utils_fast.py b/mindone/transformers/image_processing_utils_fast.py index 65b2b8eaba..2e3a5e19fa 100644 --- a/mindone/transformers/image_processing_utils_fast.py +++ b/mindone/transformers/image_processing_utils_fast.py @@ -677,8 +677,12 @@ def _preprocess( resized_images_grouped = {} for shape, stacked_images in grouped_images.items(): if do_resize: - stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) - resized_images_grouped[shape] = stacked_images + stacked_images_updated = [] + for i in range(len(stacked_images)): + stacked_images_updated.append( + self.resize(image=stacked_images[i], size=size, interpolation=interpolation) + ) + resized_images_grouped[shape] = stacked_images_updated resized_images = reorder_images(resized_images_grouped, grouped_images_index) # Group images by size for further processing From ffd337714892e02a80f4b4810579c7bc489b4448 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:36:48 +0800 Subject: [PATCH 44/94] remove torchvision decoders --- mindone/transformers/video_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mindone/transformers/video_utils.py b/mindone/transformers/video_utils.py index f8735c1ad2..ac9d540dd7 100644 --- a/mindone/transformers/video_utils.py +++ b/mindone/transformers/video_utils.py @@ -447,7 +447,6 @@ def sample_indices_fn(metadata, **kwargs): "decord": read_video_decord, "opencv": read_video_opencv, "pyav": read_video_pyav, - "torchvision": read_video_mindspore, "mindspore": read_video_mindspore, } From b38bf63775b586ccec5840e7f5909b6f49a633a3 Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:43:30 +0800 Subject: [PATCH 45/94] fix get_default_dtype bug --- mindone/transformers/modeling_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index ffecab0e53..bd52689553 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -441,6 +441,15 @@ def _get_mindspore_dtype( f'`mindspore_dtype` can be either `ms.Type` or `"auto"`, but received {mindspore_dtype}' ) # TODO: We cannot set default mindspore dtype! + else: + # set fp32 as the default dtype for BC + # TODO: We cannot get default mindspore dtype! 
Therefore, we set default dtype to ms.float32 + default_dtype = dtype_to_str(ms.float32) + config.mindspore_dtype = default_dtype + for key in config.sub_configs.keys(): + value = getattr(config, key) + value.mindspore_dtype = default_dtype + return config, mindspore_dtype From f32b7cb169a710d8a6957d9d92cd1d90450ab3f2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:49:54 +0800 Subject: [PATCH 46/94] load module dynamically from mindone/transformers --- mindone/transformers/processing_utils.py | 35 +++++++++--------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py index 130f8b5100..841138fea3 100644 --- a/mindone/transformers/processing_utils.py +++ b/mindone/transformers/processing_utils.py @@ -1358,29 +1358,20 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs) @staticmethod def get_possibly_dynamic_module(module_name): - if hasattr(transformers_module, module_name): - return getattr(transformers_module, module_name) - lookup_locations = [ - transformers_module.IMAGE_PROCESSOR_MAPPING, - transformers_module.VIDEO_PROCESSOR_MAPPING, - transformers_module.TOKENIZER_MAPPING, - transformers_module.FEATURE_EXTRACTOR_MAPPING, - transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING, - ] - for lookup_location in lookup_locations: - for custom_class in lookup_location._extra_content.values(): - if isinstance(custom_class, tuple): - for custom_subclass in custom_class: - if custom_subclass is not None and custom_subclass.__name__ == module_name: - return custom_subclass - elif custom_class is not None and custom_class.__name__ == module_name: - return custom_class + + if "ImageProcess" in module_name: + sub_path = os.path.abspath(os.path.dirname(__file__)) + sub_path = str(Path(sub_path).parent) + sys.path.insert(0, sub_path) + mindone_transformers_module = importlib.import_module("mindone.transformers") + if not hasattr(mindone_transformers_module, module_name): + raise ValueError(f"Expect to have `{module_name}` registered in `mindone.transformers`, but failed to load it!") + return getattr(mindone_transformers_module, module_name) else: - raise ValueError( - f"Could not find module {module_name} in `transformers`. If this is a custom class, " - f"it should be registered using the relevant `AutoClass.register()` function so that " - f"other functions can find it!" 
-            )
+            if hasattr(transformers_module, module_name):
+                return getattr(transformers_module, module_name)
+            else:
+                raise ValueError(f"Expect to have `{module_name}` registered in `transformers`, but failed to load it!")
 
     @property
     def model_input_names(self):

From 2cb578b6bbbec9a18dfd73933bb27cba7280e21c Mon Sep 17 00:00:00 2001
From: Didan Deng <33117903+wtomin@users.noreply.github.com>
Date: Mon, 25 Aug 2025 13:04:51 +0800
Subject: [PATCH 47/94] not support FA

---
 .../modeling_flash_attention_utils.py         | 138 +-----------------
 1 file changed, 1 insertion(+), 137 deletions(-)

diff --git a/mindone/transformers/modeling_flash_attention_utils.py b/mindone/transformers/modeling_flash_attention_utils.py
index 49d58197de..ae283d445c 100644
--- a/mindone/transformers/modeling_flash_attention_utils.py
+++ b/mindone/transformers/modeling_flash_attention_utils.py
@@ -265,50 +265,6 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[ms.Type] = None):
     return q, k, v
 
 
-# TODO: fix this flash attention 2 and 3
-def _lazy_imports(impl: Optional[str]):
-    # returns funcs and pad/unpad based on impl
-    is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
-    is_fa3 = is_flash_attn_3_available()
-    if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
-        try:
-            from flash_attn import flash_attn_func, flash_attn_varlen_func
-            from flash_attn.bert_padding import pad_input, unpad_input
-
-            return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
-
-        except ImportError as e:
-            if not globals().get("use_remote_fa2", None):
-                use_remote_fa2 = (
-                    input(
-                        "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
-                    )
-                    .strip()
-                    .lower()
-                )
-                globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
-            if globals()["use_remote_fa2"]:
-                raise NotImplementedError("Remote flash attention 2 is not supported yet.")
-            else:
-                raise ImportError(
-                    "Failed to import flash attention 2, please install it or use another implementation."
- ) from e - if impl == "flash_attention_3" or (impl is None and is_fa3): - from flash_attn_interface import flash_attn_func, flash_attn_varlen_func - - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input - return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True - else: - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input - return ( - getattr(impl, "flash_attn_func", None), - getattr(impl, "flash_attn_varlen_func"), - pad_input, - unpad_input, - True, - ) - - _flash_supports_window = None @@ -348,96 +304,4 @@ def _flash_attention_forward( implementation: Optional[str] = None, **kwargs, ): - if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): - flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) - globals()["_flash_fn"] = flash_fn - globals()["_flash_varlen_fn"] = flash_varlen_fn - globals()["_pad_fn"] = pad_fn - globals()["_unpad_fn"] = unpad_fn - globals()["_is_fa3"] = is_fa3 - flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters - globals()["_flash_supports_window"] = flash_supports_window - else: - flash_fn = globals()["_flash_fn"] - flash_varlen_fn = globals()["_flash_varlen_fn"] - pad_fn = globals()["_pad_fn"] - unpad_fn = globals()["_unpad_fn"] - is_fa3 = globals()["_is_fa3"] - flash_supports_window = globals()["_flash_supports_window"] - - causal = is_causal and not (use_top_left_mask and query_length == 1) - use_sw = ( - (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window - ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} - if not is_fa3: - flash_kwargs["dropout_p"] = dropout - # if is_flash_attn_greater_or_equal("2.4.1"): - # det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - # flash_kwargs["deterministic"] = det - if softcap is not None: - flash_kwargs["softcap"] = softcap - - query_states, key_states, value_states = fa_peft_integration_check( - query_states, key_states, value_states, target_dtype - ) - use_mask = position_ids is not None or all( - k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k] - ) - if attention_mask is not None: - q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( - query_states, key_states, value_states, attention_mask, query_length, unpad_fn - ) - - out_unpad = flash_varlen_fn( - q, - k, - v, - cu_seqlens_q=cu_q.to(ms.int32), - cu_seqlens_k=cu_k.to(ms.int32), - max_seqlen_q=mq, - max_seqlen_k=mk, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - if isinstance(out_unpad, tuple): - out_unpad = out_unpad[0] - out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) - elif use_mask: - if cu_seq_lens_q is None or cu_seq_lens_k is None: - if position_ids is None: - raise ValueError( - "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." 
- ) - q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( - query_states, key_states, value_states, position_ids - ) - else: - q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) - k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) - v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - mq, mk = max_length_q, max_length_k - cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k - - out = flash_varlen_fn( - q, - k, - v, - cu_seqlens_q=cu_q.to(ms.int32), - cu_seqlens_k=cu_k.to(ms.int32), - max_seqlen_q=mq, - max_seqlen_k=mk, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - if isinstance(out, tuple): - out = out[0] - out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) - else: - out = flash_fn( - query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs - ) - - return out[0] if isinstance(out, tuple) else out + raise NotImplementedError("`_flash_attention_forward` is not supported yet. Use `mindone.transformers.integrations.flash_attention.flash_attention_forward instead!`") From 9457ebce7cbafb72b91c91813bb074286b54ac8e Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:58:41 +0800 Subject: [PATCH 48/94] add video_processing_utils --- .../transformers/video_processing_utils.py | 860 ++++++++++++++++++ 1 file changed, 860 insertions(+) create mode 100644 mindone/transformers/video_processing_utils.py diff --git a/mindone/transformers/video_processing_utils.py b/mindone/transformers/video_processing_utils.py new file mode 100644 index 0000000000..926719c5ce --- /dev/null +++ b/mindone/transformers/video_processing_utils.py @@ -0,0 +1,860 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import warnings +from copy import deepcopy +from typing import Any, Optional, Union + +import numpy as np +from transformers.dynamic_module_utils import custom_object_save +from transformers.utils import ( + VIDEO_PROCESSOR_NAME, + add_start_docstrings, + cached_file, + copy_func, + download_url, + is_offline_mode, + is_remote_url, +) + +from .image_processing_utils import BatchFeature, get_size_dict +from .image_processing_utils_fast import BaseImageProcessorFast +from .image_utils import ChannelDimension, SizeDict, validate_kwargs +from .processing_utils import Unpack, VideosKwargs +from .utils import TensorType, is_mindspore_available, is_vision_available, logging +from .video_utils import ( + VideoInput, + VideoMetadata, + group_videos_by_shape, + load_video, + make_batched_videos, + reorder_videos, + to_channel_dimension_format, +) + +if is_vision_available(): + from .image_utils import PILImageResampling + +if is_mindspore_available(): + import mindspore as ms + from mindspore import mint + + from .image_utils import pil_torch_interpolation_mapping + + +logger = logging.get_logger(__name__) + + +BASE_VIDEO_PROCESSOR_DOCSTRING = r""" + Args: + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the video's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `self.size`): + Size of the output video after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + size_divisor (`int`, *optional*, defaults to `self.size_divisor`): + The size by which to make sure both the height and width can be divided. + default_to_square (`bool`, *optional*, defaults to `self.default_to_square`): + Whether to default to a square video when resizing, if size is an int. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the video to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + do_pad (`bool`, *optional*): + Whether to pad the video to the `(max_height, max_width)` of the videos in the batch. + crop_size (`dict[str, int]` *optional*, defaults to `self.crop_size`): + Size of the output video after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the video by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the video. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the video. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the video. 
This is a float or list of floats the length of the number of
+            channels in the video. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+            Standard deviation to use if normalizing the video. This is a float or list of floats the length of the
+            number of channels in the video. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+            Whether to convert the video to RGB.
+        video_metadata (`VideoMetadata`, *optional*):
+            Metadata of the video containing information about total duration, fps and total number of frames.
+        do_sample_frames (`bool`, *optional*, defaults to `self.do_sample_frames`):
+            Whether to sample frames from the video before processing or to process the whole video.
+        num_frames (`int`, *optional*, defaults to `self.num_frames`):
+            Maximum number of frames to sample when `do_sample_frames=True`.
+        fps (`int` or `float`, *optional*, defaults to `self.fps`):
+            Target frames to sample per second when `do_sample_frames=True`.
+        return_tensors (`str` or `TensorType`, *optional*):
+            Returns stacked tensors if set, otherwise returns a list of tensors.
+        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+            The channel dimension format for the output video. Can be one of:
+            - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
+            - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
+            - Unset: Use the channel dimension format of the input video.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the input video. If unset, the channel dimension format is inferred
+            from the input video. Can be one of:
+            - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
+            - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
+            - `"none"` or `ChannelDimension.NONE`: video in (height, width) format.
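+
+    Example:
+
+    A minimal usage sketch (illustrative only; it assumes a registered subclass such as
+    `LlavaOnevisionVideoProcessor` is importable from `mindone.transformers` and feeds random frames instead of a
+    real video):
+
+    ```python
+    import numpy as np
+
+    from mindone.transformers import LlavaOnevisionVideoProcessor
+
+    video_processor = LlavaOnevisionVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
+
+    # one video given as a 4D array of frames in (num_frames, height, width, num_channels) order
+    video = np.random.randint(0, 256, size=(8, 360, 640, 3), dtype=np.uint8)
+
+    # returns a `BatchFeature` whose key follows `model_input_names`, i.e. "pixel_values_videos"
+    batch = video_processor(video)
+    pixel_values_videos = batch["pixel_values_videos"]
+    ```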
+ """ + + +@add_start_docstrings( + "Constructs a base VideoProcessor.", + BASE_VIDEO_PROCESSOR_DOCSTRING, +) +class BaseVideoProcessor(BaseImageProcessorFast): + _auto_class = None + + resample = None + image_mean = None + image_std = None + size = None + size_divisor = None + default_to_square = True + crop_size = None + do_resize = None + do_center_crop = None + do_pad = None + do_rescale = None + rescale_factor = 1 / 255 + do_normalize = None + do_convert_rgb = None + do_sample_frames = None + fps = None + num_frames = None + video_metadata = None + valid_kwargs = VideosKwargs + model_input_names = ["pixel_values_videos"] + + def __init__(self, **kwargs: Unpack[VideosKwargs]) -> None: + super().__init__() + + self._processor_class = kwargs.pop("processor_class", None) + + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + # Prepare size related keys and turn then into `SizeDict` + size = kwargs.pop("size", self.size) + self.size = ( + get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square)) + if size is not None + else None + ) + crop_size = kwargs.pop("crop_size", self.crop_size) + self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + + # Save valid kwargs in a list for further processing + self.model_valid_processing_keys = list(self.valid_kwargs.__annotations__.keys()) + for key in self.model_valid_processing_keys: + if kwargs.get(key) is not None: + setattr(self, key, kwargs[key]) + else: + setattr(self, key, deepcopy(getattr(self, key, None))) + + def __call__(self, videos, **kwargs) -> BatchFeature: + return self.preprocess(videos, **kwargs) + + def convert_to_rgb( + self, + video: "ms.Tensor", + ) -> VideoInput: + """ + Converts a video to RGB format. + + Args: + video (`"ms.Tensor"`): + The video to convert. + + Returns: + `ms.Tensor`: The converted video. + """ + + video = ms.dataset.vision.c_transforms.ConvertColor(video) + if video.shape[-3] == 3 or not (video[..., 3, :, :] < 255).any(): + return video + + # There is a transparency layer, blend it with a white background. + # Calculate the alpha proportion for blending. + alpha = video[..., 3, :, :] / 255.0 + video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :] + return video + + def sample_frames( + self, + video: "ms.Tensor", + metadata: Optional[Union[VideoMetadata, dict]] = None, + num_frames: Optional[int] = None, + fps: Optional[Union[int, float]] = None, + ): + """ + Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames. + If `fps` is passed along with metadata, `fps` frames per second are sampled uniformty. Arguments `num_frames` + and `fps` are mutually exclusive. + + Args: + video (`ms.Tensor`): + Video that need to be sampled. + metadata (`VideoMetadata`, *optional*): + Metadata of the video containing information about total duration, fps and total number of frames. + num_frames (`int`, *optional*): + Maximum number of frames to sample. Defaults to `self.num_frames`. + fps (`int` or `float`, *optional*): + Target frames to sample per second. Defaults to `self.fps`. + + Returns: + ms.Tensor: + Sampled video frames. 
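+
+        Example:
+
+        A worked example of the arithmetic above (numbers are illustrative only): for a clip with
+        `total_num_frames=300` recorded at `metadata["fps"]=30`, calling this method with `fps=2` gives
+        `num_frames = int(300 / 30 * 2) = 20`, hence `indices = arange(0, 300, 15) = [0, 15, ..., 285]`
+        and 20 uniformly spaced frames are returned.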
+ """ + if fps is not None and num_frames is not None: + raise ValueError( + "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" + ) + + num_frames = num_frames if num_frames is not None else self.num_frames + fps = fps if fps is not None else self.fps + total_num_frames = video.shape[0] + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is None and fps is not None: + if metadata is None: + raise ValueError( + "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. " + "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video" + ) + num_frames = int(total_num_frames / metadata["fps"] * fps) + + if num_frames > total_num_frames: + raise ValueError( + f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. " + ) + + if num_frames is not None: + indices = mint.arange(0, total_num_frames, total_num_frames / num_frames).int() + else: + indices = mint.arange(0, total_num_frames).int() + + video = video[indices].contiguous() + return video + + def _prepare_input_videos( + self, + videos: VideoInput, + video_metadata: VideoMetadata = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> list["ms.Tensor"]: + """ + Prepare the input videos for processing. + """ + videos = make_batched_videos(videos) + if video_metadata is not None: + batch_metadata = [metadata for batch_list in video_metadata for metadata in batch_list] + else: + batch_metadata = [None] * len(videos) + + processed_videos = [] + for video in videos: + # `make_batched_videos` always returns a 4D array per video + if isinstance(video, np.ndarray): + video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format) + # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays + video = ms.tensor(video).contiguous() + + processed_videos.append(video) + return processed_videos, batch_metadata + + @add_start_docstrings(BASE_VIDEO_PROCESSOR_DOCSTRING) + def preprocess( + self, + videos: VideoInput, + **kwargs: Unpack[VideosKwargs], + ) -> BatchFeature: + validate_kwargs( + captured_kwargs=kwargs.keys(), + valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"], + ) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. 
+ for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + input_data_format = kwargs.pop("input_data_format") + video_metadata = kwargs.pop("video_metadata") + videos, video_metadata = self._prepare_input_videos( + videos=videos, video_metadata=video_metadata, input_data_format=input_data_format + ) + + kwargs = self._further_process_kwargs(**kwargs) + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + # Pop kwargs that are not needed in _preprocess + kwargs.pop("default_to_square") + kwargs.pop("data_format") + + return self._preprocess(videos=videos, video_metadata=video_metadata, **kwargs) + + def _preprocess( + self, + videos: list["ms.Tensor"], + video_metadata: Union[list[VideoMetadata], list[dict]], + do_convert_rgb: bool, + do_resize: bool, + size: SizeDict, + size_divisor: Optional[int], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + do_pad: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_sample_frames: Optional[bool] = None, + fps: Optional[Union[int, float]] = None, + num_frames: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if do_sample_frames: + # Sample video frames + videos = [ + self.sample_frames(video, metadata=metadata, num_frames=num_frames, fps=fps) + for video, metadata in zip(videos, video_metadata) + ] + + # Group videos by size for batched resizing + grouped_videos, grouped_videos_index = group_videos_by_shape(videos) + resized_videos_grouped = {} + for shape, stacked_videos in grouped_videos.items(): + if do_convert_rgb: + stacked_videos = self.convert_to_rgb(stacked_videos) + if do_resize: + stacked_videos = self.resize( + stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation + ) + resized_videos_grouped[shape] = stacked_videos + resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) + + # Group videos by size for further processing + # Needed in case do_resize is False, or resize returns videos with different sizes + grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos) + processed_videos_grouped = {} + for shape, stacked_videos in grouped_videos.items(): + if do_center_crop: + stacked_videos = self.center_crop(stacked_videos, crop_size) + # Fused rescale and normalize + stacked_videos = self.rescale_and_normalize( + stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_videos_grouped[shape] = stacked_videos + + processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) + processed_videos = mint.stack(processed_videos, dim=0) if return_tensors else processed_videos + + return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + r""" + Instantiate a 
type of [`~video_processing_utils.VideoProcessorBase`] from an video processor. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained video hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a video processor file saved using the + [`~video_processing_utils.VideoProcessorBase.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved video processor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model video processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the video processor files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final video processor object. If `True`, then this + functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of + `kwargs` which has not been used to update `video_processor` and is otherwise ignored. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are video processor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + Returns: + A video processor of type [`~video_processing_utils.ImagVideoProcessorBase`]. + + Examples: + + ```python + # We can't instantiate directly the base class *VideoProcessorBase* so let's show the examples on a + # derived class: *LlavaOnevisionVideoProcessor* + video_processor = LlavaOnevisionVideoProcessor.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ) # Download video_processing_config from huggingface.co and cache. + video_processor = LlavaOnevisionVideoProcessor.from_pretrained( + "./test/saved_model/" + ) # E.g. 
video processor (or model) was saved using *save_pretrained('./test/saved_model/')* + video_processor = LlavaOnevisionVideoProcessor.from_pretrained("./test/saved_model/preprocessor_config.json") + video_processor = LlavaOnevisionVideoProcessor.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False + ) + assert video_processor.do_normalize is False + video_processor, unused_kwargs = LlavaOnevisionVideoProcessor.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False, return_unused_kwargs=True + ) + assert video_processor.do_normalize is False + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + video_processor_dict, kwargs = cls.get_video_processor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(video_processor_dict, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save an video processor object to the directory `save_directory`, so that it can be re-loaded using the + [`~video_processing_utils.VideoProcessorBase.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the video processor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token", None) is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. 
+ if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_video_processor_file = os.path.join(save_directory, VIDEO_PROCESSOR_NAME) + + self.to_json_file(output_video_processor_file) + logger.info(f"Video processor saved in {output_video_processor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_video_processor_file] + + @classmethod + def get_video_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + video processor of type [`~video_processing_utils.VideoProcessorBase`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the video processor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + resolved_video_processor_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + video_processor_file = pretrained_model_name_or_path + resolved_video_processor_file = download_url(pretrained_model_name_or_path) + else: + try: + # Try to load with a new config name first and if not successfull try with + # the old file name. 
In case we can load with old name only, raise a deprecation warning + # Deprecated until v5.0 + video_processor_file = VIDEO_PROCESSOR_NAME + resolved_video_processor_file = cached_file( + pretrained_model_name_or_path, + video_processor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + ) + except OSError: + video_processor_file = "preprocessor_config.json" + resolved_video_processor_file = cached_file( + pretrained_model_name_or_path, + video_processor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + ) + logger.warning_once( + "You have video processor config saved in `preprocessor.json` file which is deprecated. " + "Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename " + "the file or load and save the processor back which renames it automatically. " + "Loading from `preprocessor.json` will be removed in v5.0." + ) + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {VIDEO_PROCESSOR_NAME} file" + ) + + try: + # Load video_processor dict + with open(resolved_video_processor_file, "r", encoding="utf-8") as reader: + text = reader.read() + video_processor_dict = json.loads(text) + + except json.JSONDecodeError: + raise OSError( + f"It looks like the config file at '{resolved_video_processor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_video_processor_file}") + else: + logger.info( + f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}" + ) + return video_processor_dict, kwargs + + @classmethod + def from_dict(cls, video_processor_dict: dict[str, Any], **kwargs): + """ + Instantiates a type of [`~video_processing_utils.VideoProcessorBase`] from a Python dictionary of parameters. + + Args: + video_processor_dict (`dict[str, Any]`): + Dictionary that will be used to instantiate the video processor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~video_processing_utils.VideoProcessorBase.to_dict`] method. + kwargs (`dict[str, Any]`): + Additional parameters from which to initialize the video processor object. + + Returns: + [`~video_processing_utils.VideoProcessorBase`]: The video processor object instantiated from those + parameters. + """ + video_processor_dict = video_processor_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # The `size` parameter is a dict and was previously an int or tuple in feature extractors. 
+ # We set `size` here directly to the `video_processor_dict` so that it is converted to the appropriate + # dict within the video processor and isn't overwritten if `size` is passed in as a kwarg. + if "size" in kwargs and "size" in video_processor_dict: + video_processor_dict["size"] = kwargs.pop("size") + if "crop_size" in kwargs and "crop_size" in video_processor_dict: + video_processor_dict["crop_size"] = kwargs.pop("crop_size") + + video_processor = cls(**video_processor_dict) + + # Update video_processor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(video_processor, key): + setattr(video_processor, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Video processor {video_processor}") + if return_unused_kwargs: + return video_processor, kwargs + else: + return video_processor + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance. + """ + output = deepcopy(self.__dict__) + output.pop("model_valid_processing_keys", None) + output.pop("_valid_kwargs_names", None) + output["video_processor_type"] = self.__class__.__name__ + + return output + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this image_processor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]): + """ + Instantiates a video processor of type [`~video_processing_utils.VideoProcessorBase`] from the path to a JSON + file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A video processor of type [`~video_processing_utils.VideoProcessorBase`]: The video_processor object + instantiated from that JSON file. + """ + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + video_processor_dict = json.loads(text) + return cls(**video_processor_dict) + + @classmethod + def register_for_auto_class(cls, auto_class="AutoVideoProcessor"): + """ + Register this class with a given auto class. This should only be used for custom video processors as the ones + in the library are already mapped with `AutoVideoProcessor `. + + + + This API is experimental and may have some slight breaking changes in the next releases. 
+ + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoVideoProcessor "`): + The auto class to register this new video processor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def fetch_videos(self, video_url_or_urls: Union[str, list[str]]): + """ + Convert a single or a list of urls into the corresponding `np.array` objects. + + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + if isinstance(video_url_or_urls, list): + return [self.fetch_videos(x) for x in video_url_or_urls] + elif isinstance(video_url_or_urls, str): + return load_video(video_url_or_urls) + else: + raise TypeError(f"only a single or a list of entries is supported but got type={type(video_url_or_urls)}") + + +BaseVideoProcessor.push_to_hub = copy_func(BaseVideoProcessor.push_to_hub) +if BaseVideoProcessor.push_to_hub.__doc__ is not None: + BaseVideoProcessor.push_to_hub.__doc__ = BaseVideoProcessor.push_to_hub.__doc__.format( + object="video processor", object_class="AutoVideoProcessor", object_files="video processor file" + ) From 32031d01a839144b46b7f8a972f994f04193f6b3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 25 Aug 2025 20:25:37 +0800 Subject: [PATCH 49/94] fix import error/add audio_utils/fix processor bug/attn_implementation check * fix generic.py error * fix generic.py error * audio_utils.py * audio_utils.py * fix errors * update processing_chameleon * update processing_idefics * update processing_llava_next * update processing_llava_next_video * update processing_llava_next_video * update processing_llava_next_video * update processing_llava_next_video * update processing_qwen_2_5_omni * update processing_siglip_fast * rm ernie 45 * sdpa does not support * sdpa does not support: aria --- mindone/transformers/audio_utils.py | 326 ++++++++++-------- .../image_processing_utils_fast.py | 2 +- mindone/transformers/masking_utils.py | 2 +- .../modeling_flash_attention_utils.py | 14 +- mindone/transformers/modeling_utils.py | 2 +- mindone/transformers/models/__init__.py | 1 - .../models/blip/image_processing_blip_fast.py | 9 +- .../models/chameleon/processing_chameleon.py | 61 +++- .../models/idefics/processing_idefics.py | 11 +- .../llava_next/processing_llava_next.py | 84 +++-- .../image_processing_llava_next_video.py | 3 +- .../processing_llava_next_video.py | 47 ++- .../processing_llava_onevision.py | 3 +- .../video_processing_llava_onevision.py | 3 +- .../qwen2_5_omni/processing_qwen2_5_omni.py | 3 +- .../siglip/image_processing_siglip_fast.py | 9 +- mindone/transformers/processing_utils.py | 2 +- mindone/transformers/utils/generic.py | 4 +- mindone/transformers/utils/import_utils.py | 8 + .../models/albert/test_modeling_albert.py | 1 + .../models/aria/test_modeling_aria.py | 1 + .../test_modeling_xlm_roberta_xl.py | 6 +- 22 files changed, 351 insertions(+), 251 deletions(-) diff --git a/mindone/transformers/audio_utils.py b/mindone/transformers/audio_utils.py index 6059e3ba42..46c5cdde7e 100644 --- a/mindone/transformers/audio_utils.py +++ b/mindone/transformers/audio_utils.py @@ -18,11 +18,178 @@ Audio processing functions to extract features from audio waveforms. 
This code is pure numpy to support all frameworks and remove unnecessary dependencies. """ - +import base64 +import io +import os import warnings -from typing import List, Optional, Tuple, Union - +from io import BytesIO +from typing import Any, List, Optional, Tuple, Union +import mindspore as ms import numpy as np +import requests + +from .utils import is_librosa_available, is_numpy_array, is_soundfile_available, is_mindspore_tensor, requires_backends + +if is_soundfile_available(): + import soundfile as sf + +if is_librosa_available(): + import librosa + + # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa + import soxr + + +def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray: + """ + Loads `audio` to an np.ndarray object. + + Args: + audio (`str` or `np.ndarray`): + The audio to be loaded to the numpy array format. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate to be used when loading the audio. It should be same as the + sampling rate the model you will be using further was trained with. + timeout (`float`, *optional*): + The timeout value in seconds for the URL request. + + Returns: + `np.ndarray`: A numpy array representing the audio. + """ + requires_backends(load_audio, ["librosa"]) + + if isinstance(audio, str): + # Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav) + if audio.startswith("http://") or audio.startswith("https://"): + audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0] + elif os.path.isfile(audio): + audio = librosa.load(audio, sr=sampling_rate)[0] + elif isinstance(audio, np.ndarray): + audio = audio + else: + raise TypeError( + "Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array." + ) + return audio + + +def load_audio_as( + audio: str, + return_format: str, + timeout: Optional[int] = None, + force_mono: bool = False, + sampling_rate: Optional[int] = None, +) -> Union[str, dict[str, Any], io.BytesIO, None]: + """ + Load audio from either a local file path or URL and return in specified format. + + Args: + audio (`str`): Either a local file path or a URL to an audio file + return_format (`str`): Format to return the audio in: + - "base64": Base64 encoded string + - "dict": Dictionary with data and format + - "buffer": BytesIO object + timeout (`int`, *optional*): Timeout for URL requests in seconds + force_mono (`bool`): Whether to convert stereo audio to mono + sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate. + + Returns: + `Union[str, Dict[str, Any], io.BytesIO, None]`: + - `str`: Base64 encoded audio data (if return_format="base64") + - `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict") + - `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer") + """ + # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa + requires_backends(load_audio_as, ["librosa"]) + + if return_format not in ["base64", "dict", "buffer"]: + raise ValueError(f"Invalid return_format: {return_format}. 
Must be 'base64', 'dict', or 'buffer'") + + try: + # Load audio bytes from URL or file + audio_bytes = None + if audio.startswith(("http://", "https://")): + response = requests.get(audio, timeout=timeout) + response.raise_for_status() + audio_bytes = response.content + elif os.path.isfile(audio): + with open(audio, "rb") as audio_file: + audio_bytes = audio_file.read() + else: + raise ValueError(f"File not found: {audio}") + + # Process audio data + with io.BytesIO(audio_bytes) as audio_file: + with sf.SoundFile(audio_file) as f: + audio_array = f.read(dtype="float32") + original_sr = f.samplerate + audio_format = f.format + if sampling_rate is not None and sampling_rate != original_sr: + # Resample audio to target sampling rate + audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ") + else: + sampling_rate = original_sr + + # Convert to mono if needed + if force_mono and audio_array.ndim != 1: + audio_array = audio_array.mean(axis=1) + + buffer = io.BytesIO() + sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper()) + buffer.seek(0) + + if return_format == "buffer": + return buffer + elif return_format == "base64": + return base64.b64encode(buffer.read()).decode("utf-8") + elif return_format == "dict": + return { + "data": base64.b64encode(buffer.read()).decode("utf-8"), + "format": audio_format.lower(), + } + + except Exception as e: + raise ValueError(f"Error loading audio: {e}") + + +AudioInput = Union[ + np.ndarray, + "ms.Tensor", + list[np.ndarray], + tuple[np.ndarray], + list["ms.Tensor"], + tuple["ms.Tensor"], # noqa: F821 +] + + +def is_valid_audio(audio): + return is_numpy_array(audio) or is_mindspore_tensor(audio) + + +def is_valid_list_of_audio(audio): + return audio and all(is_valid_audio(audio_i) for audio_i in audio) + + +def make_list_of_audio( + audio: Union[list[AudioInput], AudioInput], +) -> AudioInput: + """ + Ensure that the output is a list of audio. + Args: + audio (`Union[list[AudioInput], AudioInput]`): + The input audio. + Returns: + list: A list of audio. + """ + # If it's a list of audios, it's already in the right format + if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio): + return audio + + # If it's a single audio, convert it to a list of + if is_valid_audio(audio): + return [audio] + + raise ValueError("Invalid input type. Must be a single audio or a list of audio") def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: @@ -247,7 +414,8 @@ def mel_filter_bank( Args: num_frequency_bins (`int`): - Number of frequencies used to compute the spectrogram (should be the same as in `stft`). + Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier\ + Transform used to compute the spectrogram). num_mel_filters (`int`): Number of mel filters to generate. 
min_frequency (`float`): @@ -271,6 +439,12 @@ def mel_filter_bank( if norm is not None and norm != "slaney": raise ValueError('norm must be one of None or "slaney"') + if num_frequency_bins < 2: + raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2") + + if min_frequency > max_frequency: + raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}") + # center points of the triangular mel filters mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) @@ -279,7 +453,7 @@ def mel_filter_bank( if triangularize_in_mel_space: # frequencies of FFT bins in Hz, but filters triangularized in mel space - fft_bin_width = sampling_rate / (num_frequency_bins * 2) + fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2) fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale) filter_freqs = mel_freqs else: @@ -978,145 +1152,3 @@ def amplitude_to_db_batch( spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) return spectrogram - - -def get_mel_filter_banks( - nb_frequency_bins: int, - nb_mel_filters: int, - frequency_min: float, - frequency_max: float, - sample_rate: int, - norm: Optional[str] = None, - mel_scale: str = "htk", -) -> np.array: - warnings.warn( - "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers", - FutureWarning, - ) - return mel_filter_bank( - num_frequency_bins=nb_frequency_bins, - num_mel_filters=nb_mel_filters, - min_frequency=frequency_min, - max_frequency=frequency_max, - sampling_rate=sample_rate, - norm=norm, - mel_scale=mel_scale, - ) - - -def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True): - """ - In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed - segments called `frames`. - - The window length (window_length) defines how much of the signal is contained in each frame, while the hop length - defines the step between the beginning of each new frame. - - - Args: - waveform (`np.array` of shape `(sample_length,)`): - The raw waveform which will be split into smaller chunks. - hop_length (`int`, *optional*, defaults to 160): - Step between each window of the waveform. - fft_window_size (`int`, *optional*, defaults to 400): - Defines the size of the window. - center (`bool`, defaults to `True`): - Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the - waveform on the left and on the right. - - Return: - framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`): - The framed waveforms that can be fed to `np.fft`. 
- """ - warnings.warn( - "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers", - FutureWarning, - ) - frames = [] - for i in range(0, waveform.shape[0] + 1, hop_length): - if center: - half_window = (fft_window_size - 1) // 2 + 1 - start = i - half_window if i > half_window else 0 - end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0] - frame = waveform[start:end] - if start == 0: - padd_width = (-i + half_window, 0) - frame = np.pad(frame, pad_width=padd_width, mode="reflect") - - elif end == waveform.shape[0]: - padd_width = (0, (i - waveform.shape[0] + half_window)) - frame = np.pad(frame, pad_width=padd_width, mode="reflect") - - else: - frame = waveform[i : i + fft_window_size] - frame_width = frame.shape[0] - if frame_width < waveform.shape[0]: - frame = np.lib.pad( - frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0 - ) - frames.append(frame) - - frames = np.stack(frames, 0) - return frames - - -def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None): - """ - Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results - as `torch.stft`. - - Args: - frames (`np.array` of dimension `(num_frames, fft_window_size)`): - A framed audio signal obtained using `audio_utils.fram_wav`. - windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`: - A array representing the function that will be used to reduces the amplitude of the discontinuities at the - boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function. - For more information on the discontinuities, called *Spectral leakage*, refer to [this - tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf - fft_window_size (`int`, *optional*): - Size of the window om which the Fourier transform is applied. This controls the frequency resolution of the - spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. The number of - frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to - `(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculus time proportionnally. 
- - Example: - - ```python - >>> from transformers.audio_utils import stft, fram_wave - >>> import numpy as np - - >>> audio = np.random.rand(50) - >>> fft_window_size = 10 - >>> hop_length = 2 - >>> framed_audio = fram_wave(audio, hop_length, fft_window_size) - >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1)) - ``` - - Returns: - spectrogram (`np.ndarray`): - A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm - """ - warnings.warn( - "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers", - FutureWarning, - ) - frame_size = frames.shape[1] - - if fft_window_size is None: - fft_window_size = frame_size - - if fft_window_size < frame_size: - raise ValueError("FFT size must greater or equal the frame size") - # number of FFT bins to store - nb_frequency_bins = (fft_window_size >> 1) + 1 - - spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64) - fft_signal = np.zeros(fft_window_size) - - for f, frame in enumerate(frames): - if windowing_function is not None: - np.multiply(frame, windowing_function, out=fft_signal[:frame_size]) - else: - fft_signal[:frame_size] = frame - spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins] - return spectrogram.T diff --git a/mindone/transformers/image_processing_utils_fast.py b/mindone/transformers/image_processing_utils_fast.py index 2e3a5e19fa..c43e5c9189 100644 --- a/mindone/transformers/image_processing_utils_fast.py +++ b/mindone/transformers/image_processing_utils_fast.py @@ -54,7 +54,7 @@ if is_mindspore_available(): import mindspore as ms - import mindspore.mint.functional as F + import mindspore.mint.nn.functional as F from mindspore import mint from mindspore.dataset import vision from mindspore.dataset.vision import Inter as InterpolationMode diff --git a/mindone/transformers/masking_utils.py b/mindone/transformers/masking_utils.py index fa30c6256c..e7329a447b 100644 --- a/mindone/transformers/masking_utils.py +++ b/mindone/transformers/masking_utils.py @@ -21,7 +21,7 @@ from transformers.configuration_utils import PretrainedConfig import mindspore as ms -import mindspore.mint.functional as F +import mindspore.mint.nn.functional as F from mindspore import mint from .cache_utils import Cache diff --git a/mindone/transformers/modeling_flash_attention_utils.py b/mindone/transformers/modeling_flash_attention_utils.py index ae283d445c..8d79eea2fc 100644 --- a/mindone/transformers/modeling_flash_attention_utils.py +++ b/mindone/transformers/modeling_flash_attention_utils.py @@ -14,15 +14,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
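The recurring import fix in the hunks above (`mindspore.mint.functional` → `mindspore.mint.nn.functional`) matches the torch-style layout of the `mint` namespace. A minimal sketch of the corrected idiom, assuming the corresponding `mint` ops are available in the installed MindSpore version:

```python
import mindspore as ms
import mindspore.mint.nn.functional as F   # functional ops live under mint.nn.functional
from mindspore import mint

x = mint.ones((2, 3), dtype=ms.float32)
y = F.gelu(x, approximate="tanh")          # mirrors torch.nn.functional.gelu
```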
-import inspect -import os + import warnings from typing import Optional, TypedDict from transformers.utils import logging import mindspore as ms -import mindspore.mint.functional as F +import mindspore.mint.nn.functional as F from mindspore import mint logger = logging.get_logger(__name__) @@ -268,11 +267,6 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[ms.Type] = None): _flash_supports_window = None -def is_flash_attn_available(): - # return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() - return True - - def flash_attn_supports_top_left_mask(): raise NotImplementedError("flash_attn_supports_top_left_mask is not supported yet.") @@ -304,4 +298,6 @@ def _flash_attention_forward( implementation: Optional[str] = None, **kwargs, ): - raise NotImplementedError("`_flash_attention_forward` is not supported yet. Use `mindone.transformers.integrations.flash_attention.flash_attention_forward instead!`") + raise NotImplementedError( + "`_flash_attention_forward` is not supported yet. Use `mindone.transformers.integrations.flash_attention.flash_attention_forward instead!`" + ) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index bd52689553..e1fa8b7b1d 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -1294,7 +1294,7 @@ def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool: 'Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' ) if not is_sdpa_available(): - raise ImportError("MindSpore SDPA requirements in Transformers are not met.") + raise ImportError("MindSpore SDPA requirements in Transformers are not met. Use `attn_implementation='eager'` instead.") return True diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 8aa4b72e23..0400abc7e3 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -39,7 +39,6 @@ depth_anything, dinov2, dpt, - ernie4_5, fuyu, gemma, gemma2, diff --git a/mindone/transformers/models/blip/image_processing_blip_fast.py b/mindone/transformers/models/blip/image_processing_blip_fast.py index 9f2811baf3..5d3c4ac390 100644 --- a/mindone/transformers/models/blip/image_processing_blip_fast.py +++ b/mindone/transformers/models/blip/image_processing_blip_fast.py @@ -17,16 +17,13 @@ # limitations under the License. 
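The sharpened `ImportError` above points users at the eager fallback. The sketch below simply restates the call already quoted in the surrounding error message; it assumes `AutoModel` is exported from `mindone.transformers` as in upstream transformers, and the checkpoint name is only illustrative:

```python
from mindone.transformers import AutoModel

# Fall back to the eager attention implementation when SDPA requirements are not met.
model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")
```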
"""Fast Image processor class for BLIP.""" -from transformers.utils import add_start_docstrings +from transformers.utils import auto_docstring -from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast +from ...image_processing_utils_fast import BaseImageProcessorFast from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling -@add_start_docstrings( - "Constructs a fast BLIP image processor.", - BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, -) +@auto_docstring class BlipImageProcessorFast(BaseImageProcessorFast): # To be checked against the slow image processor # None values left after checking can be removed diff --git a/mindone/transformers/models/chameleon/processing_chameleon.py b/mindone/transformers/models/chameleon/processing_chameleon.py index 120afed1f1..cc95fdf408 100644 --- a/mindone/transformers/models/chameleon/processing_chameleon.py +++ b/mindone/transformers/models/chameleon/processing_chameleon.py @@ -20,13 +20,13 @@ """ from typing import List, Optional, Union +import numpy as np from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -import mindspore as ms from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack class ChameleonTextKwargs(TextKwargs, total=False): @@ -39,6 +39,7 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False): "text_kwargs": { "padding": False, "return_for_text_completion": False, + "return_mm_token_type_ids": False, }, "common_kwargs": { "return_tensors": "ms", @@ -67,16 +68,21 @@ class ChameleonProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") - valid_kwargs = ["image_seq_length", "image_token"] + image_processor_class = "ChameleonImageProcessor" def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): self.image_seq_length = image_seq_length self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) self.image_start_token = ( tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "" ) # fixed tokens for start and end, so can hardcode self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "" + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token) + self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token) + self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id] super().__init__(image_processor, tokenizer) @@ -92,7 +98,7 @@ def __call__( Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to - CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. 
Please refer to the docstring of the above two methods for more information. Args: @@ -120,8 +126,7 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) + if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): @@ -145,15 +150,45 @@ def __call__( sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode prompt_strings.append(sample) - output_kwargs["text_kwargs"].pop("return_tensors", None) - data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors="np") - for k, v in data.items(): - data[k] = ms.tensor(v) - + image_inputs = {} if images is not None: - data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
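The `return_mm_token_type_ids` path added above marks every image-related position (image, BOI and EOI tokens) with `1` via `np.isin`. A small standalone sketch of that masking step, using made-up token ids:

```python
import numpy as np

# Made-up token ids for illustration: <image>, BOI and EOI respectively.
image_ids = [8711, 8197, 8196]
input_ids = np.array([[1, 8197, 8711, 8711, 8196, 42]])

mm_token_type_ids = np.zeros_like(input_ids)
mm_token_type_ids[np.isin(input_ids, image_ids)] = 1
print(mm_token_type_ids.tolist())   # [[0, 1, 1, 1, 1, 0]]
```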
+ """ + + vision_data = {} + if image_sizes is not None: + # add 2 for BOI and EOI tokens + num_image_tokens = [self.image_seq_length + 2] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) - return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) + return MultiModalData(**vision_data) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/mindone/transformers/models/idefics/processing_idefics.py b/mindone/transformers/models/idefics/processing_idefics.py index ec9ecb3b59..c2bf332325 100644 --- a/mindone/transformers/models/idefics/processing_idefics.py +++ b/mindone/transformers/models/idefics/processing_idefics.py @@ -30,14 +30,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ( - ImagesKwargs, - ProcessingKwargs, - ProcessorMixin, - TextKwargs, - Unpack, - _validate_images_text_input_order, -) +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack IMAGE_TOKEN = "" @@ -297,8 +290,6 @@ def __call__( """ if images is None and text is None: raise ValueError("You need to specify either `text` or `images` and `text`.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) if images is None: # assuming the user wants to use the old behavior with prompts as the only argument diff --git a/mindone/transformers/models/llava_next/processing_llava_next.py b/mindone/transformers/models/llava_next/processing_llava_next.py index de1fc14339..8d67f00419 100644 --- a/mindone/transformers/models/llava_next/processing_llava_next.py +++ b/mindone/transformers/models/llava_next/processing_llava_next.py @@ -22,14 +22,13 @@ from typing import List, Union +import numpy as np from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -import mindspore as ms - from ...feature_extraction_utils import BatchFeature from ...image_processing_utils import select_best_resolution from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...utils import logging logger = logging.get_logger(__name__) @@ -39,6 +38,7 @@ class LlavaNextProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": False, + "return_mm_token_type_ids": False, }, "images_kwargs": { "do_pad": True, @@ -62,7 +62,7 @@ class LlavaNextProcessor(ProcessorMixin): Patch size from the vision tower. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. - Shoudl be same as in model's config + Should be same as in model's config chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. 
image_token (`str`, *optional*, defaults to `""`): @@ -73,13 +73,6 @@ class LlavaNextProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [ - "chat_template", - "patch_size", - "vision_feature_select_strategy", - "image_token", - "num_additional_image_tokens", - ] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -98,6 +91,11 @@ def __init__( self.num_additional_image_tokens = num_additional_image_tokens self.vision_feature_select_strategy = vision_feature_select_strategy self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( @@ -112,7 +110,7 @@ def __call__( Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to - LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring of the above two methods for more information. Args: @@ -135,8 +133,6 @@ def __call__( """ if images is None and text is None: raise ValueError("You have to specify at least images or text.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( LlavaNextProcessorKwargs, @@ -151,7 +147,7 @@ def __call__( if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, or a list of strings") + raise TypeError("Invalid input text. 
Please provide a string, or a list of strings") prompt_strings = text if image_inputs: @@ -172,12 +168,18 @@ def __call__( prompt_strings.append(sample) prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] - output_kwargs["text_kwargs"].pop("return_tensors", None) - text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors="np") - for k, v in text_inputs.items(): - text_inputs[k] = ms.tensor(v) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) - return BatchFeature(data={**text_inputs, **image_inputs}) + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: image_grid_pinpoints = self.image_processor.image_grid_pinpoints @@ -221,6 +223,48 @@ def _get_unpadded_features(self, height, width, patches_height, patches_width, s newline_features = current_height return (unpadded_features, newline_features) + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (list[list[str]], *optional*): + The input sizes formatted as (height, width) per each image. + video_sizes (list[list[str]], *optional*): + The input sizes formatted as (num_frames, height, width) per each video. + audio_lengths (list[int], *optional*): + The input length formatted as per each audio. + Returns: + dict[str, list[int]]: A dictionary mapping each modality ("image", "video", "audio") + to a list containing the number of placeholder tokens required. If the model doesn't accept + a certain modality or no input sizes are provided, the dict value is set to an empty list. 
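For the LLaVA-NeXT counterpart of `_get_num_multimodal_tokens` defined above, a hypothetical usage sketch: the checkpoint name and image size are assumptions, and `MultiModalData` is assumed to expose the per-modality counts as attributes.

```python
from mindone.transformers.models.llava_next.processing_llava_next import LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
mm_data = processor._get_num_multimodal_tokens(image_sizes=[(480, 640)])

# One entry per input image; the exact token count depends on image_grid_pinpoints
# and on vision_feature_select_strategy.
print(mm_data.num_image_tokens, mm_data.num_image_patches)
```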
+ """ + vision_data = {} + if image_sizes is not None: + images_kwargs = LlavaNextProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + size = images_kwargs.get("size", None) or self.image_processor.size + size = ( + (size["shortest_edge"], size["shortest_edge"]) + if "shortest_edge" in size + else (min(size["height"], size["width"]), min(size["height"], size["width"])) + ) + processed_height, processed_width = size + + batch_num_image_tokens = [] + num_image_patches = [1] * len(image_sizes) # llava-next doesn't batch pixels as Idefics, thus `1` patch` + for image_size in image_sizes: + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features( + orig_height, orig_width, processed_height, processed_width + ) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + batch_num_image_tokens.append(num_image_tokens) + vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): """ diff --git a/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py b/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py index 3981c8944f..91b7af98e3 100644 --- a/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py +++ b/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py @@ -30,15 +30,14 @@ ChannelDimension, ImageInput, PILImageResampling, - VideoInput, infer_channel_dimension_format, is_scaled_image, - make_batched_videos, make_list_of_images, to_numpy_array, validate_preprocess_arguments, ) from ...utils import TensorType, logging +from ...video_utils import VideoInput, make_batched_videos logger = logging.get_logger(__name__) diff --git a/mindone/transformers/models/llava_next_video/processing_llava_next_video.py b/mindone/transformers/models/llava_next_video/processing_llava_next_video.py index 417a1bcded..d4fb66609f 100644 --- a/mindone/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/mindone/transformers/models/llava_next_video/processing_llava_next_video.py @@ -25,13 +25,12 @@ import numpy as np from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -import mindspore as ms - from ...feature_extraction_utils import BatchFeature from ...image_processing_utils import select_best_resolution -from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...image_utils import ImageInput, get_image_size, to_numpy_array +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...utils import logging +from ...video_utils import VideoInput logger = logging.get_logger(__name__) @@ -57,7 +56,7 @@ class LlavaNextVideoProcessor(ProcessorMixin): [`LlamaTokenizerFast`]. See the [`~LlavaNextVideoProcessor.__call__`] and [`~LlavaNextVideoProcessor.decode`] for more information. Args: - video_processor ([`LlavaNextVideoImageProcessor`], *optional*): + video_processor ([`LlavaNextVideoVideoProcessor`], *optional*): The video processor is a required input. image_processor ([`LlavaNextImageProcessor`], *optional*): The image processor is a required input. 
@@ -69,7 +68,7 @@ class LlavaNextVideoProcessor(ProcessorMixin): Patch size from the vision tower. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. - Shoudl be same as in model's config + Should be same as in model's config video_token (`str`, *optional*, defaults to `"