@@ -227,11 +227,11 @@ def rotate_half(x):
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
     Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
+        q (`ms.Tensor`): The query tensor.
+        k (`ms.Tensor`): The key tensor.
+        cos (`ms.Tensor`): The cosine part of the rotary embedding.
+        sin (`ms.Tensor`): The sine part of the rotary embedding.
+        position_ids (`ms.Tensor`):
             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
             used to pass offsetted position ids when working with a KV-cache.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
@@ -242,7 +242,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
     Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
     # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
     # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
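The `unsqueeze_dim` behavior documented in the hunk above is easiest to see with concrete shapes. A minimal, illustrative sketch (assuming the usual `rotate_half` convention and a recent MindSpore `ops` API; not this file's exact code path):

import mindspore as ms
from mindspore import ops

def rotate_half(x):
    # split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = ops.chunk(x, 2, axis=-1)
    return ops.cat((-x2, x1), axis=-1)

# illustrative shapes: q is [batch_size, heads, seq_len, head_dim]
bsz, heads, seq_len, head_dim = 1, 8, 16, 64
q = ops.ones((bsz, heads, seq_len, head_dim), ms.float32)
cos = ops.ones((bsz, seq_len, head_dim), ms.float32)
sin = ops.zeros((bsz, seq_len, head_dim), ms.float32)

# unsqueeze_dim=1 makes cos/sin [batch_size, 1, seq_len, head_dim], which
# broadcasts against q/k; with a [batch_size, seq_len, heads, head_dim]
# layout, unsqueeze_dim=2 would be used instead.
q_embed = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)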
@@ -613,13 +613,13 @@ def _flash_attention_forward(
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
         Args:
-            query_states (`torch.Tensor`):
+            query_states (`ms.Tensor`):
                 Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
+            key_states (`ms.Tensor`):
                 Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
+            value_states (`ms.Tensor`):
                 Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
+            attention_mask (`ms.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
             dropout (`int`, *optional*):
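As a quick illustration of the `attention_mask` contract described above (a `(batch_size, seq_len)` map with 1 for real tokens and 0 for padding), here is a hedged sketch of how such a mask is commonly derived from `input_ids`; the `pad_token_id` value is purely illustrative:

import mindspore as ms

pad_token_id = 0  # illustrative; the real id comes from the config/tokenizer
input_ids = ms.Tensor([[5, 17, 42, 0, 0]], ms.int32)

# (batch_size, seq_len): 1 marks non-padding tokens, 0 marks padding
attention_mask = (input_ids != pad_token_id).astype(ms.int32)
# -> [[1, 1, 1, 0, 0]]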
@@ -646,6 +646,8 @@ def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
         compute_dtype = str_to_dtype(config.mindspore_dtype)
 
+        self.is_first_iteration = True
+
         self.infer_attention = InferAttention(
             config.num_attention_heads,
             self.head_dim,
@@ -658,14 +660,12 @@ def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
             next_tokens=0,
             block_size=32,
             num_blocks=1024,
-            is_dynamic=True,
+            is_dynamic=True if not self.is_first_iteration else False,
             use_flash_attention=True,
             use_rope_rotary_emb=False,
             compute_dtype=compute_dtype,
         )
 
-        self.is_first_iteration = True
-
     def construct(
         self,
         hidden_states: ms.Tensor,
@@ -713,9 +713,9 @@ def construct(
         value_states = value_states.swapaxes(1, 2).reshape(bsz, q_len, -1)
 
         if not self.is_first_iteration:
-            query_states = query_states[:, -1, :].reshape(bsz, 1, -1)
-            key_states = key_states[:, -1, :].reshape(bsz, 1, -1)
-            value_states = value_states[:, -1, :].reshape(bsz, 1, -1)
+            query_states = query_states[:, -1, :].reshape(bsz, 1, self.num_heads * self.head_dim)
+            key_states = key_states[:, -1, :].reshape(bsz, 1, self.num_key_value_heads * self.head_dim)
+            value_states = value_states[:, -1, :].reshape(bsz, 1, self.num_key_value_heads * self.head_dim)
 
         attn_output = self.infer_attention(
             query_states,
@@ -775,8 +775,8 @@ def construct(
     ) -> Tuple[ms.Tensor, Optional[Tuple[ms.Tensor, ms.Tensor]]]:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
+            hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`ms.Tensor`, *optional*):
                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                 query_sequence_length, key_sequence_length)` if default attention is used.
             output_attentions (`bool`, *optional*):
@@ -785,7 +785,7 @@ def construct(
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            past_key_value (`Tuple(ms.Tensor)`, *optional*): cached past key and value projection states
         """
         if "padding_mask" in kwargs:
             warnings.warn(
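The docstring above distinguishes a 2-D mask (flash attention) from a 4-D mask (default attention). A rough sketch of expanding a 2-D padding mask into the 4-D additive form, with illustrative shapes and not this file's exact helper:

import mindspore as ms
from mindspore import ops

mask_2d = ms.Tensor([[1, 1, 1, 0]], ms.float32)  # (batch_size, sequence_length)
bsz, seq_len = mask_2d.shape

# expand to (batch_size, 1, query_len, key_len), then turn 0s into a large
# negative bias that the attention softmax will suppress
mask_4d = ops.expand_dims(ops.expand_dims(mask_2d, 1), 1)        # (bsz, 1, 1, seq_len)
mask_4d = mask_4d.broadcast_to((bsz, 1, seq_len, seq_len))       # (bsz, 1, q_len, k_len)
additive_mask = (1.0 - mask_4d) * -1e9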
@@ -866,13 +866,13 @@ def _init_weights(self, module):
 
 MINICPM_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+        input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
             [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+        attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
@@ -886,25 +886,25 @@ def _init_weights(self, module):
             information on the default strategy.
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        position_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+        past_key_values (`Cache` or `tuple(tuple(ms.Tensor))`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
             Two formats are allowed:
             - a [`~cache_utils.Cache`] instance;
-            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+            - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of
             shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
             cache format.
             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
             legacy cache format will be returned.
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
             of shape `(batch_size, sequence_length)`.
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+        inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
             model's internal embedding lookup matrix.
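For concreteness, the legacy `past_key_values` layout described above is just a tuple of per-layer `(key, value)` pairs; a minimal, purely illustrative sketch with made-up sizes:

import mindspore as ms
from mindspore import ops

n_layers, bsz, num_heads, seq_len, head_dim = 2, 1, 8, 16, 64  # illustrative
past_key_values = tuple(
    (
        ops.zeros((bsz, num_heads, seq_len, head_dim), ms.float16),  # cached keys
        ops.zeros((bsz, num_heads, seq_len, head_dim), ms.float16),  # cached values
    )
    for _ in range(n_layers)
)
# With a cache like this, only the newest token ids, of shape (batch_size, 1),
# need to be fed on the next decoding step.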
@@ -1109,20 +1109,22 @@ def __init__(self, config):
         if self.config._attn_implementation == "paged_attention":
             compute_dtype = str_to_dtype(config.mindspore_dtype)
 
+            self.is_first_iteration = True
+
             self.freqs_mgr = FreqsMgr(
                 head_dim=config.hidden_size // config.num_attention_heads,
                 seq_length=config.max_position_embeddings,
                 max_position_embedding=config.max_position_embeddings,
                 rotary_dtype=compute_dtype,
                 theta=config.rope_theta,
-                is_dynamic=True,
+                is_dynamic=True if not self.is_first_iteration else False,
             )
 
             self.casual_mask = LowerTriangularMaskWithDynamic(
                 seq_length=config.max_position_embeddings,
                 batch_size=1,
                 compute_type=compute_dtype,
-                is_dynamic=True,
+                is_dynamic=True if not self.is_first_iteration else False,
                 pad_token_id=config.pad_token_id,
                 use_flash_attention=True,
                 use_attn_mask_compression=False,
@@ -1131,8 +1133,6 @@ def __init__(self, config):
                 chunk_prefill=False,
             )
 
-            self.is_first_iteration = True
-
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
@@ -1215,7 +1215,7 @@ def construct(
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
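A short sketch of the `-100` convention the docstring above describes: padding positions are mapped to `-100` so the loss skips them (the `pad_token_id` and token ids below are illustrative):

import mindspore as ms
from mindspore import ops

pad_token_id = 0  # illustrative
input_ids = ms.Tensor([[5, 17, 42, 0, 0]], ms.int32)

# labels mirror input_ids, with padding positions replaced by -100
ignore_index = ms.Tensor(-100, ms.int32)
labels = ops.where(input_ids == pad_token_id, ignore_index, input_ids)
# -> [[5, 17, 42, -100, -100]]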
@@ -1542,7 +1542,7 @@ def construct(
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+        labels (`ms.Tensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
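The `num_labels` rule above usually boils down to which loss gets applied; a hedged sketch (not necessarily the exact code in this file):

from mindspore import nn

num_labels = 1  # illustrative; taken from config.num_labels in practice

if num_labels == 1:
    loss_fct = nn.MSELoss()           # regression (Mean-Square loss)
else:
    loss_fct = nn.CrossEntropyLoss()  # classification (Cross-Entropy)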