-
Notifications
You must be signed in to change notification settings - Fork 14
refactor MLA #91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor MLA #91
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,290 @@ | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from megatron.core import parallel_state, tensor_parallel | ||
| from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb | ||
| from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, | ||
| gather_from_tensor_model_parallel_region, | ||
| scatter_to_sequence_parallel_region) | ||
| from megatron.core.transformer.multi_latent_attention import MLASelfAttention as McoreMLASelfAttention | ||
| from megatron.core.utils import deprecate_inference_params | ||
|
|
||
|
|
||
| class MLASelfAttention(McoreMLASelfAttention): | ||
|
|
||
| def get_query_key_value_tensors( | ||
| self, | ||
| hidden_states, | ||
| key_value_states=None, | ||
| position_ids=None, | ||
| packed_seq_params=None, | ||
| inference_context=None, | ||
| rotary_pos_emb=None, | ||
| *, | ||
| inference_params=None, | ||
| ): | ||
| """ | ||
| Derives `query`, `key` and `value` tensors from `hidden_states`. | ||
| """ | ||
| # s = sequence length, b = batch size, h = hidden size, n = num attention heads | ||
| # Attention heads [s, b, n*h] | ||
| assert (hidden_states.ndim == 3), f'hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D' | ||
|
|
||
| inference_context = deprecate_inference_params(inference_context, inference_params) | ||
|
|
||
| if packed_seq_params is not None: | ||
| cu_seqlens_q = packed_seq_params.cu_seqlens_q | ||
| cu_seqlens_kv = packed_seq_params.cu_seqlens_kv | ||
| else: | ||
| cu_seqlens_q = cu_seqlens_kv = None | ||
|
|
||
| # ========================================= | ||
| # QKV down projection and layernorm | ||
| # ========================================= | ||
| if self.config.q_lora_rank is not None: | ||
| # if linear_q_down_proj is ColumnParallelLinear: | ||
| # q_compressed: [s, b, q_lora_rank / TP] | ||
| # elif linear_q_down_proj is Linear: | ||
| # q_compressed: [s / TP, b, q_lora_rank] | ||
| q_compressed, _ = self.linear_q_down_proj(hidden_states) | ||
|
|
||
| # When output is sharded (ColumnParallelLinear), two things are needed to be | ||
| # identical to a normal Linear. | ||
| # 1. Manually gather output to restore output dim q_lora_rank; | ||
| # 2. Scatter sequence back to s / TP if sequence-parallel since it was | ||
| # gathered by ColumnParallelLinear. | ||
| if q_compressed.size(-1) != self.config.q_lora_rank: | ||
| q_compressed = gather_from_tensor_model_parallel_region(q_compressed) | ||
| if self.config.sequence_parallel: | ||
| q_compressed = scatter_to_sequence_parallel_region(q_compressed) | ||
|
|
||
| q_compressed = self.q_layernorm(q_compressed) | ||
| else: | ||
| q_compressed = hidden_states | ||
|
|
||
| # if linear_kv_down_proj is ColumnParallelLinear: | ||
| # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP] | ||
| # elif linear_kv_down_proj is Linear: | ||
| # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] | ||
| kv_combined, _ = self.linear_kv_down_proj(hidden_states) | ||
| if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: | ||
| # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] | ||
| kv_combined = gather_from_tensor_model_parallel_region(kv_combined) | ||
| # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] | ||
| kv_compressed, k_pos_emb = torch.split( | ||
| kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1) | ||
| if self.config.sequence_parallel: | ||
| # kv_compressed:[s / TP, b, kv_lora_rank] | ||
| kv_compressed = scatter_to_sequence_parallel_region(kv_compressed) | ||
| else: | ||
| # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] | ||
| kv_compressed, k_pos_emb = torch.split( | ||
| kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1) | ||
| if parallel_state.get_tensor_model_parallel_world_size() > 1: | ||
| # k_pos_emb: [s, b, qk_pos_emb_head_dim] | ||
| k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb) | ||
|
|
||
| kv_compressed = self.kv_layernorm(kv_compressed) | ||
|
|
||
| # ========================================= | ||
| # QKV up projection and RoPE apply | ||
| # ========================================= | ||
| def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): | ||
| """ | ||
| Apply the up projection and RoPE to the query and key. | ||
| When sequence packing enabled, the input tensors adopt a packed shape of [t, ...]; | ||
| otherwise, they maintain the unpacked shape [s, b, ...]. In subsequent code comments, | ||
| we uniformly use [num_tokens, ...] to denote [s, b, ...] or [t, ...] for two cases. | ||
| """ | ||
| if self.config.q_lora_rank is not None: | ||
| # q_compressed: [num_tokens, q_lora_rank] | ||
| # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)] | ||
| q, _ = self.linear_q_up_proj(q_compressed) | ||
| else: | ||
| # q_compressed: [num_tokens, hidden_size] | ||
| # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)] | ||
| q, _ = self.linear_q_proj(q_compressed) | ||
|
|
||
| # q: [num_tokens, n, q_head_dim] | ||
| q = q.view(*q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim) | ||
|
|
||
| # kv: [num_tokens, n * (qk_head_dim + v_head_dim)] | ||
| kv, _ = self.linear_kv_up_proj(kv_compressed) | ||
|
|
||
| # kv: [num_tokens, n, (qk_head_dim + v_head_dim)] | ||
| kv = kv.view( | ||
| *kv.size()[:-1], | ||
| self.num_attention_heads_per_partition, | ||
| self.config.qk_head_dim + self.config.v_head_dim, | ||
| ) | ||
|
|
||
| q_len = q.size()[0] | ||
| if inference_context is not None: | ||
| # add offset to the sequence start for inference | ||
| sequence_start = inference_context.sequence_len_offset | ||
| sequence_end = sequence_start + q_len | ||
| rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end] | ||
| # NOTE: there is intentionally no else branch here; it was removed to fix cp (presumably context parallelism - confirm). | ||
|
|
||
| # [num_tokens, qk_pos_emb_head_dim] -> [num_tokens, 1, qk_pos_emb_head_dim] | ||
| k_pos_emb = torch.unsqueeze(k_pos_emb, -2) | ||
|
|
||
| # q_no_pe: [num_tokens, n, qk_head_dim] | ||
| # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] | ||
| q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) | ||
|
|
||
| # k_no_pe: [num_tokens, n, qk_head_dim] | ||
| # value: [num_tokens, n, v_head_dim] | ||
| k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1) | ||
| # NOTE: apply_rotary_pos_emb is expected to be patched so that it supports mscale. | ||
| # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] | ||
| q_pos_emb = apply_rotary_pos_emb( | ||
| q_pos_emb, | ||
| rotary_pos_emb, | ||
| config=self.config, | ||
| cu_seqlens=cu_seqlens_q, | ||
| ) | ||
| # k_pos_emb:[num_tokens, 1, qk_pos_emb_head_dim] | ||
| k_pos_emb = apply_rotary_pos_emb( | ||
| k_pos_emb, | ||
| rotary_pos_emb, | ||
| config=self.config, | ||
| cu_seqlens=cu_seqlens_kv, | ||
| ) | ||
|
|
||
| # query: [num_tokens, n, (qk_head_dim + qk_pos_emb_head_dim)] | ||
| query = torch.cat([q_no_pe, q_pos_emb], dim=-1) | ||
|
|
||
| # key: [num_tokens, n, (qk_head_dim + qk_pos_emb_head_dim)] | ||
|
Jintao-Huang marked this conversation as resolved.
Outdated
|
||
| if k_pos_emb.ndim == 4: | ||
| k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1) | ||
| else: | ||
| assert k_pos_emb.ndim == 3 | ||
| k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1) | ||
| key = torch.cat([k_no_pe, k_pos_emb], dim=-1) | ||
|
|
||
| query = query.contiguous() | ||
| key = key.contiguous() | ||
| value = value.contiguous() | ||
| return query, key, value | ||
|
|
||
| if packed_seq_params is not None: | ||
| # With sequence packing, TE expects [t, h, d] shaped qkv input. | ||
| # In Megatron-Core, the qkv shape is [t, 1, h, d]. | ||
| # So dim 1 is squeezed out of the compressed inputs here, before the up projection, to get packed [t, ...] tensors. | ||
| q_compressed = q_compressed.squeeze(1) | ||
| kv_compressed = kv_compressed.squeeze(1) | ||
| k_pos_emb = k_pos_emb.squeeze(1) | ||
|
|
||
| if self.recompute_up_proj: | ||
| self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() | ||
| query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed, | ||
| kv_compressed, k_pos_emb, rotary_pos_emb) | ||
| else: | ||
| query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) | ||
|
|
||
| return query, key, value, q_compressed, kv_compressed | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states, | ||
| attention_mask, | ||
| key_value_states=None, | ||
| inference_context=None, | ||
| rotary_pos_emb=None, | ||
| rotary_pos_cos=None, | ||
| rotary_pos_sin=None, | ||
| rotary_pos_cos_sin=None, | ||
| attention_bias=None, | ||
| packed_seq_params=None, | ||
| position_ids=None, | ||
| sequence_len_offset=None, | ||
| *, | ||
| inference_params=None, | ||
| ): | ||
| """Forward pass for multi-latent attention""" | ||
| assert attention_bias is None, 'Attention bias should not be passed into MLA.' | ||
| assert (rotary_pos_cos is None and rotary_pos_sin is None), 'MLA does not support Flash Decoding' | ||
|
|
||
| # hidden_states: [sq, b, h] | ||
|
|
||
| inference_context = deprecate_inference_params(inference_context, inference_params) | ||
|
|
||
| # ===================== | ||
| # Query, Key, and Value | ||
| # ===================== | ||
| # Get the query, key and value tensors (RoPE already applied), together with | ||
| # the compressed q_compressed / kv_compressed tensors returned alongside them. | ||
| # Example shapes (assumed illustrative - confirm): query: [96, 1, 16, 128], key: [96, 1, 16, 128], value: [96, 1, 16, 128] | ||
| query, key, value, q_compressed, kv_compressed = self.get_query_key_value_tensors( | ||
| hidden_states, | ||
| key_value_states, | ||
| position_ids, | ||
| packed_seq_params, | ||
| rotary_pos_emb=rotary_pos_emb, | ||
| inference_context=inference_context, | ||
| ) | ||
|
|
||
| # =================================================== | ||
| # Adjust key, value for inference | ||
| # =================================================== | ||
| # rotary_pos_emb=None is passed below; RoPE has already been applied to query/key in get_query_key_value_tensors. | ||
| query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference( | ||
| inference_context, query, key, value, rotary_pos_emb=None) | ||
|
|
||
| # TODO: Currently, TE can only accept contiguous tensors for MLA | ||
| query = query.contiguous() | ||
| key = key.contiguous() | ||
| value = value.contiguous() | ||
|
Comment on lines
+236
to
+239
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # ================================== | ||
| # core attention computation | ||
| # ================================== | ||
| # Need corresponding TE change. Until then, for the 'thd' format the last dim of value is zero-padded up to query's last dim; the padding is sliced off the attention output below. | ||
| thd_qkv_format = packed_seq_params and packed_seq_params.qkv_format == 'thd' | ||
| v_dim = value.shape[-1] | ||
| if thd_qkv_format and query.shape[-1] != v_dim: | ||
| value = F.pad(value, [0, query.shape[-1] - v_dim]) | ||
| self.core_attention.hidden_size_per_attention_head_v = value.shape[-1] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modifying |
||
| if self.checkpoint_core_attention and self.training: | ||
| core_attn_out = self._checkpointed_attention_forward( | ||
| query, key, value, attention_mask, packed_seq_params=packed_seq_params) | ||
| else: | ||
| extra_kwargs = {} | ||
| if self.config.experimental_attention_variant == 'dsa': | ||
| # For dsa we need to pass in the original hidden states and the compressed | ||
| # query representation. | ||
| extra_kwargs['x'] = hidden_states | ||
| extra_kwargs['qr'] = q_compressed | ||
| # for easy injection of rotary_pos_emb (patch) | ||
| packed_seq_params = (packed_seq_params, rotary_pos_emb) | ||
| core_attn_out = self.core_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attention_mask, | ||
| packed_seq_params=packed_seq_params, | ||
| attn_mask_type=attn_mask_type, | ||
| **extra_kwargs, | ||
| ) | ||
| if thd_qkv_format: | ||
| if core_attn_out.ndim == 2: | ||
| core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-1], -1, value.shape[-1]) | ||
| if query.shape[-1] != v_dim: | ||
| core_attn_out = core_attn_out[..., :v_dim] | ||
| # reshape to same output shape as unpacked case | ||
| # (t, np, hn) -> (t, b=1, h=np*hn) | ||
| # t is the pack size = sum (sq_i) | ||
| # note that batch is a dummy dimension in the packed case | ||
| core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) | ||
|
|
||
| if self.recompute_up_proj: | ||
| assert self.qkv_up_checkpoint is not None | ||
| self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out) | ||
| self.qkv_up_checkpoint = None | ||
|
|
||
| # ================= | ||
| # Output. [sq, b, h] | ||
| # ================= | ||
| output, bias = self.linear_proj(core_attn_out) | ||
|
|
||
| return output, bias | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment contains trailing junk text "Add commentMore actions", which appears to be a copy-paste artifact from a code review UI.