Commit 895e5c0

[Open Source Internship] Mamba2 model migration (#2009)
1 parent 10b74e8 commit 895e5c0

File tree

12 files changed: +1618 -5 lines changed

mindnlp/transformers/models/__init__.py (+3)

@@ -135,6 +135,7 @@
     luke,
     lxmert,
     mamba,
+    mamba2,
     marian,
     markuplm,
     m2m_100,
@@ -381,6 +382,7 @@
 from .lxmert import *
 from .m2m_100 import *
 from .mamba import *
+from .mamba2 import *
 from .marian import *
 from .markuplm import *
 from .maskformer import *
@@ -626,6 +628,7 @@
 __all__.extend(lxmert.__all__)
 __all__.extend(m2m_100.__all__)
 __all__.extend(mamba.__all__)
+__all__.extend(mamba2.__all__)
 __all__.extend(marian.__all__)
 __all__.extend(markuplm.__all__)
 __all__.extend(maskformer.__all__)

mindnlp/transformers/models/auto/configuration_auto.py (+3)

@@ -135,6 +135,7 @@
         ("lxmert", "LxmertConfig"),
         ("m2m_100", "M2M100Config"),
         ("mamba", "MambaConfig"),
+        ("mamba2", "Mamba2Config"),
         ("marian", "MarianConfig"),
         ('markuplm', "MarkupLMConfig"),
         ("mask2former", "Mask2FormerConfig"),
@@ -353,6 +354,7 @@
         ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("mamba2", "MAMBA2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("marian", "MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -608,6 +610,7 @@
         ("lxmert", "LXMERT"),
         ("m2m_100", "M2M100"),
         ("mamba", "Mamba"),
+        ("mamba2", "Mamba2"),
         ("marian", "Marian"),
         ("markuplm", "MarkupLM"),
         ("mask2former", "Mask2Former"),

mindnlp/transformers/models/auto/modeling_auto.py (+4)

@@ -151,6 +151,7 @@
         ("lxmert", "LxmertModel"),
         ("m2m_100", "M2M100Model"),
         ("mamba", "MambaModel"),
+        ("mamba2", "Mamba2Model"),
         ("marian", "MarianModel"),
         ("markuplm", "MarkupLMModel"),
         ("mask2former", "Mask2FormerModel"),
@@ -318,6 +319,7 @@
         ("luke", "LukeForMaskedLM"),
         ("lxmert", "LxmertForPreTraining"),
         ("mamba", "MambaForCausalLM"),
+        ("mamba2", "Mamba2ForCausalLM"),
         ("mega", "MegaForMaskedLM"),
         ("megatron-bert", "MegatronBertForPreTraining"),
         ('minicpm', 'MiniCPMForCausalLM'),
@@ -405,6 +407,7 @@
         ("luke", "LukeForMaskedLM"),
         ("m2m_100", "M2M100ForConditionalGeneration"),
         ("mamba", "MambaForCausalLM"),
+        ("mamba2", "Mamba2ForCausalLM"),
         ("marian", "MarianMTModel"),
         ("mega", "MegaForMaskedLM"),
         ("megatron-bert", "MegatronBertForCausalLM"),
@@ -491,6 +494,7 @@
         ("jetmoe", "JetMoeForCausalLM"),
         ("llama", "LlamaForCausalLM"),
         ("mamba", "MambaForCausalLM"),
+        ("mamba2", "Mamba2ForCausalLM"),
         ("marian", "MarianForCausalLM"),
         ("mbart", "MBartForCausalLM"),
         ("mega", "MegaForCausalLM"),

mindnlp/transformers/models/auto/tokenization_auto.py (+1)

@@ -269,6 +269,7 @@
         ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
         ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
         ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+        ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
         ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
         (
             "mbart",
mindnlp/transformers/models/mamba2/__init__.py (new file, +23)

@@ -0,0 +1,23 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Mamba2 Model.
+"""
+from . import modeling_mamba2, configuration_mamba2
+from .modeling_mamba2 import *
+from .configuration_mamba2 import *
+
+__all__ = []
+__all__.extend(modeling_mamba2.__all__)
+__all__.extend(configuration_mamba2.__all__)
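
With these re-exports in place, the Mamba2 symbols become reachable from the models package without importing the submodules directly. A quick sketch, assuming the standard mindnlp package layout:

```python
from mindnlp.transformers.models import mamba2
from mindnlp.transformers.models.mamba2 import Mamba2Config

# __all__ aggregates the modeling and configuration exports, so it is expected
# to contain "Mamba2Config" alongside the Mamba2 model classes registered in
# the auto mappings (e.g. Mamba2Model, Mamba2ForCausalLM).
print(mamba2.__all__)
```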
mindnlp/transformers/models/mamba2/configuration_mamba2.py (new file, +181)

@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MAMBA2 configuration"""
+
+import math
+
+from mindnlp.utils import logging
+from ...configuration_utils import PretrainedConfig
+
+logger = logging.get_logger(__name__)
+
+class Mamba2Config(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the MAMBA2
+    [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        num_heads (`int`, *optional*, defaults to 128):
+            Number of heads for the evolution matrices of mamba 2.
+        head_dim (`int`, *optional*, defaults to 64):
+            Dimension of each head.
+        vocab_size (`int`, *optional*, defaults to 32768):
+            Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Mamba2Model`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the embeddings and hidden states.
+        state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
+        num_hidden_layers (`int`, *optional*, defaults to 64):
+            Number of hidden layers in the model.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            The id of the beginning of sentence token in the vocabulary.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the end of sentence token in the vocabulary.
+        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+        n_groups (`int`, *optional*, defaults to 8):
+            Number of groups for the evolution matrices of mamba 2.
+        use_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
+        use_conv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to use bias in the convolution layer of the mixer block.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.1):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
+        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+        time_step_min (`float`, *optional*, defaults to 0.001):
+            Minimum `time_step` used to bound `dt_proj.bias`.
+        time_step_max (`float`, *optional*, defaults to 0.1):
+            Maximum `time_step` used to bound `dt_proj.bias`.
+        time_step_floor (`float`, *optional*, defaults to 0.0001):
+            Minimum clamping value of the `dt_proj.bias` layer initialization.
+        time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
+            Accepted range of time step values.
+        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+            Whether or not to rescale `out_proj` weights when initializing.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the cache should be used.
+        rms_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use RMS norm or not.
+        chunk_size (`int`, *optional*, defaults to 256):
+            Size of the chunks that will comprise the sequence.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie word embeddings or not.
+
+
+    Example:
+
+    ```python
+    >>> from transformers import Mamba2Config, Mamba2Model
+
+    >>> # Initializing a Mamba2 configuration
+    >>> configuration = Mamba2Config()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = Mamba2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "mamba2"
+
+    def __init__(
+        self,
+        num_heads=128,
+        head_dim=64,
+        vocab_size=32768,
+        hidden_size=4096,
+        state_size=128,
+        num_hidden_layers=64,
+        layer_norm_epsilon=1e-5,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        expand=2,
+        conv_kernel=4,
+        n_groups=8,
+        use_bias=False,
+        use_conv_bias=True,
+        hidden_act="silu",
+        initializer_range=0.1,
+        residual_in_fp32=True,
+        time_step_rank="auto",
+        time_step_min=0.001,
+        time_step_max=0.1,
+        time_step_floor=1e-4,
+        time_step_limit=(0.0, float("inf")),
+        rescale_prenorm_residual=False,
+        use_cache=True,
+        rms_norm=True,
+        chunk_size=256,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.state_size = state_size
+        self.num_hidden_layers = num_hidden_layers
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.conv_kernel = conv_kernel
+        self.expand = expand
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        self.use_bias = use_bias
+        self.use_conv_bias = use_conv_bias
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
+        self.time_step_min = time_step_min
+        self.time_step_max = time_step_max
+        self.time_step_floor = time_step_floor
+        self.rescale_prenorm_residual = rescale_prenorm_residual
+        self.residual_in_fp32 = residual_in_fp32
+        self.use_cache = use_cache
+        self.n_groups = n_groups
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.rms_norm = rms_norm
+        self.state_size = state_size
+        self.chunk_size = chunk_size
+        self.time_step_limit = time_step_limit
+        self.tie_word_embeddings = tie_word_embeddings
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["Mamba2Config"]
