Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 290 additions & 0 deletions src/mcore_bridge/model/modules/multi_latent_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
import torch
import torch.nn.functional as F
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
scatter_to_sequence_parallel_region)
from megatron.core.transformer.multi_latent_attention import MLASelfAttention as McoreMLASelfAttention
from megatron.core.utils import deprecate_inference_params


class MLASelfAttention(McoreMLASelfAttention):
    """Multi-latent self-attention built on Megatron-Core's ``MLASelfAttention``.

    This subclass defines its own ``get_query_key_value_tensors`` and ``forward``.
    """

    def get_query_key_value_tensors(
        self,
        hidden_states,
        key_value_states=None,
        position_ids=None,
        packed_seq_params=None,
        inference_context=None,
        rotary_pos_emb=None,
        *,
        inference_params=None,
    ):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.

        Args:
            hidden_states: 3D input tensor (asserted below), laid out as [s, b, n*h].
            key_value_states: Not read in this method.
            position_ids: Not read in this method.
            packed_seq_params: When not None, its `cu_seqlens_q` / `cu_seqlens_kv` are
                forwarded to `apply_rotary_pos_emb`, and dim 1 of the compressed
                tensors is squeezed before the up projection.
            inference_context: When not None, `rotary_pos_emb` is sliced starting at
                `inference_context.sequence_len_offset`.
            rotary_pos_emb: Rotary embedding handed to `apply_rotary_pos_emb`.
            inference_params: Merged with `inference_context` through
                `deprecate_inference_params`.

        Returns:
            The tuple `(query, key, value, q_compressed, kv_compressed)`. `query`, `key`
            and `value` are made contiguous before being returned. `q_compressed` and
            `kv_compressed` are the inputs of the up projection (already squeezed on
            dim 1 when `packed_seq_params` is given).
        """
        # s = sequence length, b = batch size, h = hidden size, n = num attention heads
        # Attention heads [s, b, n*h]
        assert (hidden_states.ndim == 3), f'hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D'

        inference_context = deprecate_inference_params(inference_context, inference_params)

        if packed_seq_params is not None:
            cu_seqlens_q = packed_seq_params.cu_seqlens_q
            cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
        else:
            cu_seqlens_q = cu_seqlens_kv = None

        # =========================================
        # QKV down projection and layernorm
        # =========================================
        if self.config.q_lora_rank is not None:
            # if linear_q_down_proj is ColumnParallelLinear:
            #     q_compressed: [s, b, q_lora_rank / TP]
            # elif linear_q_down_proj is Linear:
            #     q_compressed: [s / TP, b, q_lora_rank]
            q_compressed, _ = self.linear_q_down_proj(hidden_states)

            # When output is sharded (ColumnParallelLinear), two things are needed to be
            # identical to a normal Linear.
            #   1. Manually gather output to restore output dim q_lora_rank;
            #   2. Scatter sequence back to s / TP if sequence-parallel since it was
            #      gathered by ColumnParallelLinear.
            if q_compressed.size(-1) != self.config.q_lora_rank:
                q_compressed = gather_from_tensor_model_parallel_region(q_compressed)
                if self.config.sequence_parallel:
                    q_compressed = scatter_to_sequence_parallel_region(q_compressed)

            q_compressed = self.q_layernorm(q_compressed)
        else:
            # No low-rank query path: the up projection consumes hidden_states directly.
            q_compressed = hidden_states

        # if linear_kv_down_proj is ColumnParallelLinear:
        #     kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]
        # elif linear_kv_down_proj is Linear:
        #     kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]
        kv_combined, _ = self.linear_kv_down_proj(hidden_states)
        if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:
            # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]
            kv_combined = gather_from_tensor_model_parallel_region(kv_combined)
            # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
            kv_compressed, k_pos_emb = torch.split(
                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1)
            if self.config.sequence_parallel:
                # kv_compressed:[s / TP, b, kv_lora_rank]
                kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)
        else:
            # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]
            kv_compressed, k_pos_emb = torch.split(
                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1)
            if parallel_state.get_tensor_model_parallel_world_size() > 1:
                # k_pos_emb: [s, b, qk_pos_emb_head_dim]
                k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)

        kv_compressed = self.kv_layernorm(kv_compressed)

        # =========================================
        # QKV up projection and RoPE apply
        # =========================================
        def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):
            """
            Apply the up projection and RoPE to the query and key.
            When sequence packing enabled, the input tensors adopt a packed shape of [t, ...];
            otherwise, they maintain the unpacked shape [s, b, ...]. In subsequent code comments,
            we uniformly use [num_tokens, ...] to denote [s, b, ...] or [t, ...] for two cases.
            """
            if self.config.q_lora_rank is not None:
                # q_compressed: [num_tokens, q_lora_rank]
                # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)]
                q, _ = self.linear_q_up_proj(q_compressed)
            else:
                # q_compressed: [num_tokens, hidden_size]
                # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)]
                q, _ = self.linear_q_proj(q_compressed)

            # q: [num_tokens, n, q_head_dim]
            q = q.view(*q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim)

            # kv: [num_tokens, n * (qk_head_dim + v_head_dim)]
            kv, _ = self.linear_kv_up_proj(kv_compressed)

            # kv: [num_tokens, n, (qk_head_dim + v_head_dim)]
            kv = kv.view(
                *kv.size()[:-1],
                self.num_attention_heads_per_partition,
                self.config.qk_head_dim + self.config.v_head_dim,
            )

            q_len = q.size()[0]
            if inference_context is not None:
                # add offset to the sequence start for inference
                sequence_start = inference_context.sequence_len_offset
                sequence_end = sequence_start + q_len
                rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
            # Remove the else branch to fix cp.

            # [num_tokens, qk_pos_emb_head_dim] -> [num_tokens, 1, qk_pos_emb_head_dim]
            k_pos_emb = torch.unsqueeze(k_pos_emb, -2)

            # q_no_pe: [num_tokens, n, qk_head_dim]
            # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
            q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)

            # k_no_pe: [num_tokens, n, qk_head_dim]
            # value: [num_tokens, n, v_head_dim]
            k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)
            # This function will be patched and supports mscale.
            # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
            q_pos_emb = apply_rotary_pos_emb(
                q_pos_emb,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=cu_seqlens_q,
            )
            # k_pos_emb:[num_tokens, 1, qk_pos_emb_head_dim]
            k_pos_emb = apply_rotary_pos_emb(
                k_pos_emb,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=cu_seqlens_kv,
            )

            # query: [num_tokens, n, (qk_head_dim + qk_pos_emb_head_dim)]
            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)

            # key: [num_tokens, n, (qk_head_dim + qk_pos_emb_head_dim)]
            # k_pos_emb has a singleton head dim; broadcast it to every local head
            # so it can be concatenated with k_no_pe.
            if k_pos_emb.ndim == 4:
                k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
            else:
                assert k_pos_emb.ndim == 3
                k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)
            key = torch.cat([k_no_pe, k_pos_emb], dim=-1)

            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            return query, key, value

        if packed_seq_params is not None:
            # If sequence packing, TE expect [t, h, d] shaped qkv input.
            # In Megatron-Core, the qkv shape is [t, 1, h, d].
            # So we need to reshape qkv from [t, 1, h, d] to [t, h, d].
            q_compressed = q_compressed.squeeze(1)
            kv_compressed = kv_compressed.squeeze(1)
            k_pos_emb = k_pos_emb.squeeze(1)

        if self.recompute_up_proj:
            # Keep the checkpoint object on self: forward() later registers the
            # recompute against the attention output and then clears it.
            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()
            query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed,
                                                                  kv_compressed, k_pos_emb, rotary_pos_emb)
        else:
            query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)

        return query, key, value, q_compressed, kv_compressed

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_context=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        rotary_pos_cos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        position_ids=None,
        sequence_len_offset=None,
        *,
        inference_params=None,
    ):
        """Forward pass for multi-latent attention

        Args:
            hidden_states: Input tensor, [sq, b, h].
            attention_mask: Passed through to the core attention call.
            key_value_states: Forwarded to `get_query_key_value_tensors`.
            inference_context: Forwarded to `get_query_key_value_tensors` and
                `_adjust_key_value_for_inference`.
            rotary_pos_emb: Forwarded to `get_query_key_value_tensors`; additionally
                bundled with `packed_seq_params` for the 'dsa' attention variant.
            rotary_pos_cos: Must be None (asserted).
            rotary_pos_sin: Must be None (asserted).
            rotary_pos_cos_sin: Not read in this method.
            attention_bias: Must be None (asserted).
            packed_seq_params: Forwarded to the QKV derivation and to core attention;
                `qkv_format == 'thd'` enables the packed-sequence handling below.
            position_ids: Forwarded to `get_query_key_value_tensors`.
            sequence_len_offset: Not read in this method.
            inference_params: Merged with `inference_context` through
                `deprecate_inference_params`.

        Returns:
            The `(output, bias)` pair produced by `self.linear_proj`.
        """
        assert attention_bias is None, 'Attention bias should not be passed into MLA.'
        assert (rotary_pos_cos is None and rotary_pos_sin is None), 'MLA does not support Flash Decoding'

        # hidden_states: [sq, b, h]

        inference_context = deprecate_inference_params(inference_context, inference_params)

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        # query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
        query, key, value, q_compressed, kv_compressed = self.get_query_key_value_tensors(
            hidden_states,
            key_value_states,
            position_ids,
            packed_seq_params,
            rotary_pos_emb=rotary_pos_emb,
            inference_context=inference_context,
        )

        # ===================================================
        # Adjust key, value for inference
        # ===================================================
        # rotary_pos_emb = None
        query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference(
            inference_context, query, key, value, rotary_pos_emb=None)

        # TODO: Currently, TE can only accept contiguous tensors for MLA
        # NOTE(review): a reviewer flagged these three calls as redundant because
        # get_query_key_value_tensors already returns contiguous tensors. They are kept
        # because query/key/value are reassigned by _adjust_key_value_for_inference in
        # between, and what that helper returns is not visible here - confirm before removing.
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # ==================================
        # core attention computation
        # ==================================
        # Need corresponding TE change
        thd_qkv_format = packed_seq_params and packed_seq_params.qkv_format == 'thd'
        v_dim = value.shape[-1]
        if thd_qkv_format and query.shape[-1] != v_dim:
            # Pad value's head dim up to the query head dim; the padding is sliced
            # off the attention output again further down.
            value = F.pad(value, [0, query.shape[-1] - v_dim])
            # NOTE(review): this mutates self.core_attention during forward. A reviewer
            # flagged it as a side effect that could cause trouble if the module is
            # shared; consider configuring it at init time instead. The nesting under
            # the `if` was reconstructed from a source without indentation - confirm.
            self.core_attention.hidden_size_per_attention_head_v = value.shape[-1]

        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query, key, value, attention_mask, packed_seq_params=packed_seq_params)
        else:
            extra_kwargs = {}
            if self.config.experimental_attention_variant == 'dsa':
                # For dsa we need to pass in the original hidden states and the compressed
                # query representation.
                extra_kwargs['x'] = hidden_states
                extra_kwargs['qr'] = q_compressed
                # for easy injection of rotary_pos_emb (patch)
                # NOTE(review): placement inside the 'dsa' branch was reconstructed from a
                # source without indentation - confirm against the upstream file.
                packed_seq_params = (packed_seq_params, rotary_pos_emb)
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                packed_seq_params=packed_seq_params,
                attn_mask_type=attn_mask_type,
                **extra_kwargs,
            )
        if thd_qkv_format:
            if core_attn_out.ndim == 2:
                core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-1], -1, value.shape[-1])
            if query.shape[-1] != v_dim:
                # Drop the columns that correspond to the padding added to value above.
                core_attn_out = core_attn_out[..., :v_dim]
            # reshape to same output shape as unpacked case
            # (t, np, hn) -> (t, b=1, h=np*hn)
            # t is the pack size = sum (sq_i)
            # note that batch is a dummy dimension in the packed case
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        if self.recompute_up_proj:
            # The checkpoint object was stored by get_query_key_value_tensors.
            assert self.qkv_up_checkpoint is not None
            self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out)
            self.qkv_up_checkpoint = None

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.linear_proj(core_attn_out)

        return output, bias
Loading
Loading