1 change: 1 addition & 0 deletions mindone/transformers/__init__.py
@@ -345,5 +345,6 @@
XLMRobertaXLModel,
XLMRobertaXLPreTrainedModel,
)
from .models.zamba2 import Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model, Zamba2PreTrainedModel
from .pipelines import TextGenerationPipeline, pipeline
from .processing_utils import ProcessorMixin
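For context, a minimal usage sketch of the classes this PR newly exports. This is an illustration, not part of the PR: it assumes the mindone port mirrors the upstream Transformers from_pretrained API, and the checkpoint id is only illustrative.

import mindspore as ms
from mindone.transformers import Zamba2ForCausalLM

# Illustrative checkpoint id (not verified against this PR); any Zamba2 checkpoint should work the same way.
model = Zamba2ForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B")
input_ids = ms.Tensor([[1, 2, 3, 4]], ms.int32)
outputs = model(input_ids=input_ids)
logits = outputs[0]  # assuming tuple-style outputs, as in the precision test added below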
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
@@ -83,6 +83,7 @@
("whisper", "WhisperConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("xlm-roberta-xl", "XLMRobertaXLConfig"),
("zamba2", "Zamba2Config"),
]
)

@@ -151,6 +152,7 @@
("whisper", "Whisper"),
("xlm-roberta", "XLM-RoBERTa"),
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
("zamba2", "Zamba2"),
]
)

3 changes: 3 additions & 0 deletions mindone/transformers/models/auto/modeling_auto.py
@@ -82,6 +82,7 @@
("whisper", "WhisperModel"),
("xlm-roberta", "XLMRobertaModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
("zamba2", "Zamba2Model"),
]
)

@@ -171,6 +172,7 @@
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("zamba2", "Zamba2ForCausalLM"),
]
)

@@ -326,6 +328,7 @@
("t5", "T5ForSequenceClassification"),
("umt5", "UMT5ForSequenceClassification"),
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
("zamba2", "Zamba2ForSequenceClassification"),
]
)

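These mapping entries are what let the auto classes resolve the "zamba2" model type. A rough sketch of the intended effect, assuming mindone exports the Auto* classes at the package top level as upstream Transformers does; the import path and checkpoint id are assumptions:

from mindone.transformers import AutoModelForCausalLM  # assumed top-level export

# A checkpoint whose config.json declares model_type "zamba2" now resolves to Zamba2ForCausalLM.
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B")  # illustrative checkpoint id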
1 change: 1 addition & 0 deletions mindone/transformers/models/zamba2/__init__.py
@@ -0,0 +1 @@
from .modeling_zamba2 import Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model, Zamba2PreTrainedModel
1,618 changes: 1,618 additions & 0 deletions mindone/transformers/models/zamba2/modeling_zamba2.py

Large diffs are not rendered by default.

Empty file.
235 changes: 235 additions & 0 deletions tests/transformers_tests/models/zamba2/test_modeling_zamba2.py
@@ -0,0 +1,235 @@
# This module contains test cases structured as lists or tuples of the form
# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map].
#
# Each case pairs a PyTorch module with its MindSpore counterpart, together with their
# initialization parameters and the inputs for the forward pass. The testing framework parses
# these parameters generically and compares the precision of the forward outputs between the two frameworks.
#
# Models with unique initialization procedures, or models that must be tested with specialized
# output formats, need distinct, dedicated test cases.
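# Illustrative only (not an executable case): a schematic entry in the format above, where
# `outputs_map` maps a PyTorch output attribute name to the index of the corresponding
# MindSpore output element, e.g.
#   ["MyModel", "transformers.MyModel", "mindone.transformers.MyModel",
#    (config,), {}, (input_ids,), {"attention_mask": mask}, {"last_hidden_state": 0}]
# `Zamba2_CASES` below is a concrete instance of this layout.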

import inspect

import numpy as np
import pytest
import torch
from transformers import Zamba2Config

import mindspore as ms

from tests.modeling_test_utils import (
MS_DTYPE_MAPPING,
PT_DTYPE_MAPPING,
compute_diffs,
generalized_parse_args,
get_modules,
)
from tests.transformers_tests.models.modeling_common import ids_numpy

DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2}
MODES = [1]


class Zamba2ModelTester:
config_class = Zamba2Config

def __init__(
self,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=54,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
self.head_dim = self.hidden_size // self.num_attention_heads

# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size)

input_mask = None
if self.use_input_mask:
input_mask = np.tril(np.ones_like(input_ids))

token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_numpy([self.batch_size, self.seq_length], self.type_vocab_size)

sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_numpy([self.batch_size], self.type_sequence_label_size)
token_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_numpy([self.batch_size], self.num_choices)

config = self.get_config()

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

def get_config(self):
return self.config_class(
attn_implementation="eager",
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
head_dim=self.head_dim,
)


model_tester = Zamba2ModelTester()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = model_tester.prepare_config_and_inputs()


Zamba2_CASES = [
[
"Zamba2Model",
"transformers.Zamba2Model",
"mindone.transformers.Zamba2Model",
(config,),
{},
(input_ids,),
{
"attention_mask": input_mask,
},
{
"last_hidden_state": 0,
},
],
]


# requires transformers >= 4.41.2
@pytest.mark.parametrize(
"name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
[
case
+ [
dtype,
]
+ [
mode,
]
for case in Zamba2_CASES
for dtype in DTYPE_AND_THRESHOLDS.keys()
for mode in MODES
],
)
def test_named_modules(
name,
pt_module,
ms_module,
init_args,
init_kwargs,
inputs_args,
inputs_kwargs,
outputs_map,
dtype,
mode,
):
ms.set_context(mode=mode)

(
pt_model,
ms_model,
pt_dtype,
ms_dtype,
) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
)

    # Set `hidden_dtype` if required: some modules always compute in float
    # precision and need a specific `hidden_dtype` to cast to before returning.
if "hidden_dtype" in inspect.signature(pt_model.forward).parameters:
pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]})
ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]})
if mode == 0:
ms_inputs_kwargs.update({"use_cache": False})
with torch.no_grad():
pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)
ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)
# print("ms:", ms_outputs)
# print("pt:", pt_outputs)
if outputs_map:
pt_outputs_n = []
ms_outputs_n = []
for pt_key, ms_idx in outputs_map.items():
# print("===map", pt_key, ms_idx)
pt_output = getattr(pt_outputs, pt_key)
ms_output = ms_outputs[ms_idx]
if isinstance(pt_output, (list, tuple)):
pt_outputs_n += list(pt_output)
ms_outputs_n += list(ms_output)
else:
pt_outputs_n.append(pt_output)
ms_outputs_n.append(ms_output)
diffs = compute_diffs(pt_outputs_n, ms_outputs_n)
else:
diffs = compute_diffs(pt_outputs, ms_outputs)

THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
assert (np.array(diffs) < THRESHOLD).all(), (
f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, "
f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}"
)
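For reference, the parametrized precision test above can be run from the repository root with a standard pytest invocation, e.g.:

pytest tests/transformers_tests/models/zamba2/test_modeling_zamba2.py -v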