diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..d8d582cc1e 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -1241,6 +1241,12 @@ Siglip2TextModel, Siglip2VisionModel, ) +from .models.smollm3 import ( + SmolLM3ForCausalLM, + SmolLM3ForQuestionAnswering, + SmolLM3ForSequenceClassification, + SmolLM3ForTokenClassification, +) from .models.smolvlm import ( SmolVLMForConditionalGeneration, SmolVLMModel, diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..19dedc9a20 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -237,6 +237,7 @@ ("siglip", "SiglipConfig"), ("siglip2", "Siglip2Config"), ("siglip_vision_model", "SiglipVisionConfig"), + ("smollm3", "SmolLM3Config"), ("smolvlm", "SmolVLMConfig"), ("smolvlm_vision", "SmolVLMVisionConfig"), ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"), @@ -508,6 +509,7 @@ ("siglip2", "SigLIP2"), ("siglip2_vision_model", "Siglip2VisionModel"), ("siglip_vision_model", "SiglipVisionModel"), + ("smollm3", "SmolLM3"), ("smolvlm", "SmolVLM"), ("smolvlm_vision", "SmolVLMVisionTransformer"), ("speech-encoder-decoder", "Speech Encoder decoder"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..00d2816f86 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -216,6 +216,7 @@ ("siglip", "SiglipModel"), ("siglip2", "Siglip2Model"), ("siglip_vision_model", "SiglipVisionModel"), + ("smollm3", "SmolLM3Model"), ("smolvlm", "SmolVLMModel"), ("smolvlm_vision", "SmolVLMVisionTransformer"), ("speech_to_text", "Speech2TextModel"), @@ -318,6 +319,7 @@ ("roc_bert", "RoCBertForPreTraining"), ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("rwkv", "RwkvForCausalLM"), + ("smollm3", "SmolLM3ForCausalLM"), ("squeezebert", "SqueezeBertForMaskedLM"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), @@ -915,6 +917,7 @@ ("roc_bert", "RoCBertForSequenceClassification"), ("roberta", "RobertaForSequenceClassification"), ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"), + ("smollm3", "SmolLM3ForSequenceClassification"), ("stablelm", "StableLmForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), @@ -1066,6 +1069,7 @@ ("roberta", "RobertaForTokenClassification"), ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), ("roformer", "RoFormerForTokenClassification"), + ("smollm3", "SmolLM3ForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("squeezebert", "SqueezeBertForTokenClassification"), ("stablelm", "StableLmForTokenClassification"), diff --git a/mindone/transformers/models/smollm3/__init__.py b/mindone/transformers/models/smollm3/__init__.py new file mode 100644 index 0000000000..47c214b451 --- /dev/null +++ b/mindone/transformers/models/smollm3/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_smollm3 import * diff --git a/mindone/transformers/models/smollm3/modeling_smollm3.py b/mindone/transformers/models/smollm3/modeling_smollm3.py new file mode 100644 index 0000000000..aaa18574fb --- /dev/null +++ b/mindone/transformers/models/smollm3/modeling_smollm3.py @@ -0,0 +1,519 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/smollm3/modular_smollm3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_smollm3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +from transformers.models.smollm3.configuration_smollm3 import SmolLM3Config +from transformers.utils.deprecation import deprecate_kwarg + +import mindspore as ms +from mindspore import mint, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, can_return_tuple +from ...utils.generic import check_model_inputs + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SmolLM3Attention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: SmolLM3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + self.use_rope = config.no_rope_layers[layer_idx] + self.sliding_window = ( + config.sliding_window + if config.use_sliding_window and config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + attention_mask: Optional[ms.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ms.Tensor, Optional[ms.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + + if self.use_rope: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class SmolLM3RMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + SmolLM3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class SmolLM3MLP(nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def construct(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class SmolLM3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SmolLM3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = SmolLM3Attention(config=config, layer_idx=layer_idx) + + self.mlp = SmolLM3MLP(config) + self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> ms.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class SmolLM3PreTrainedModel(PreTrainedModel): + config: SmolLM3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SmolLM3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": SmolLM3DecoderLayer, + "attentions": SmolLM3Attention, + } + + +class SmolLM3RotaryEmbedding(nn.Cell): + inv_freq: ms.Tensor # fix linting for `register_buffer` + + def __init__(self, config: SmolLM3Config): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class SmolLM3Model(SmolLM3PreTrainedModel): + def __init__(self, config: SmolLM3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.CellList( + [SmolLM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = SmolLM3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = SmolLM3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + logits_to_keep: Union[int, ms.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import SmolLM3ForCausalLM + >>> import mindspore as ms + + >>> model = SmolLM3ForCausalLM.from_pretrained("meta-smollm3/SmolLM3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-smollm3/SmolLM3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="np") + + >>> # Generate + >>> generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class SmolLM3ForSequenceClassification(GenericForSequenceClassification, SmolLM3PreTrainedModel): + pass + + +class SmolLM3ForTokenClassification(GenericForTokenClassification, SmolLM3PreTrainedModel): + pass + + +class SmolLM3ForQuestionAnswering(GenericForQuestionAnswering, SmolLM3PreTrainedModel): + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +__all__ = [ + "SmolLM3PreTrainedModel", + "SmolLM3Model", + "SmolLM3ForCausalLM", + "SmolLM3ForSequenceClassification", + "SmolLM3ForTokenClassification", + "SmolLM3ForQuestionAnswering", +] diff --git a/tests/transformers_tests/models/smollm3/__init__.py b/tests/transformers_tests/models/smollm3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/smollm3/test_modeling_smollm3.py b/tests/transformers_tests/models/smollm3/test_modeling_smollm3.py new file mode 100644 index 0000000000..b95219a6c6 --- /dev/null +++ b/tests/transformers_tests/models/smollm3/test_modeling_smollm3.py @@ -0,0 +1,151 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the MindSpore SmolLM3 model.""" +import inspect + +import numpy as np +import pytest +import torch +import transformers + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import ids_numpy + +# default config of HuggingFaceTB/SmolLM3-3B is bf16 +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + +if transformers.__version__ >= "4.54.1": + from transformers.models.smollm3.configuration_smollm3 import SmolLM3Config + + class SmolLM3ModelTester: + def __init__( + self, + batch_size=5, + seq_length=20, + ): + self.batch_size = batch_size + self.seq_length = seq_length + + def get_config(self): + return SmolLM3Config() + + def prepare_config_and_inputs(self): + config = self.get_config() + vocab_size = config.vocab_size + input_ids = ids_numpy([self.batch_size, self.seq_length], vocab_size) + attention_mask = np.tril(np.ones_like(input_ids)) + + return config, input_ids, attention_mask + + model_tester = SmolLM3ModelTester() + config, input_ids, attention_mask = model_tester.prepare_config_and_inputs() + + SMOLLM3_CASES = [ + [ + "SmolLM3ForCausalLM", + "transformers.SmolLM3ForCausalLM", + "mindone.transformers.SmolLM3ForCausalLM", + (config,), + {}, + (), + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + { + "logits": 0, + }, + ], + ] + + @pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in SMOLLM3_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], + ) + def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, + ): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + ms_inputs_kwargs.update({"use_cache": False}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + )