Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ def _forward_core_attention(
return core_attn_out

def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params):
attention_scaling = self.config.attention_scaling
if not self.is_sliding:
self.config.attention_scaling = self.config.full_attention_scaling
Comment on lines +239 to +241
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Mutating the shared self.config object during the forward pass is not thread-safe and can lead to race conditions or incorrect behavior if multiple layers are executed concurrently or if the config is accessed elsewhere. This is particularly problematic in Megatron-Core where the same config instance is often shared across all layers. Instead of temporarily swapping the value in the config, you should determine the correct scaling factor locally and pass it as the mscale argument to the apply_rotary_pos_emb calls.

nvtx_range_push(suffix='rotary_pos_emb')
q_pos_emb, k_pos_emb = rotary_pos_emb

Expand Down Expand Up @@ -269,6 +272,7 @@ def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params):
cp_group=self.pg_collection.cp,
)
nvtx_range_pop(suffix='rotary_pos_emb')
self.config.attention_scaling = attention_scaling
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This restoration of self.config.attention_scaling is part of the problematic in-place mutation pattern. If an exception occurs during the rotary embedding application, the config might be left in an inconsistent state for subsequent layers or iterations. Passing the scaling factor as an argument to the rotary application functions would eliminate the need for this state management.

return query, key

def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -551,7 +555,7 @@ def _set_inv_freq(self):
rope_scaling = self.config.rope_scaling
self.config.rope_scaling = rope_scaling['sliding_attention']
new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
assert attention_scaling == 1, 'not support'
self.config.attention_scaling = attention_scaling
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
# full
self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
Expand All @@ -561,9 +565,8 @@ def _set_inv_freq(self):
kwargs['head_dim_key'] = 'global_head_dim'
new_inv_freq, attention_scaling = get_rope_inv_freq(
self.config, text_config=self.config.hf_config.text_config, **kwargs)
assert attention_scaling == 1, 'not support'
self.full_rotary_pos_emb.inv_freq = new_inv_freq
self.config.attention_scaling = attention_scaling
self.config.full_attention_scaling = attention_scaling
Comment thread
Jintao-Huang marked this conversation as resolved.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The attribute full_attention_scaling is being dynamically added to the config object at runtime. It should be explicitly defined in the ModelConfig class in src/mcore_bridge/config/model_config.py to ensure consistency, proper documentation, and to avoid potential AttributeError if accessed before this initialization step.


self.config.rope_scaling = rope_scaling

Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/modules/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _apply_rotary_pos_emb_bshd(
Returns:
Tensor: The input tensor after applying RoPE
"""
mscale = getattr(self.config, 'attention_scaling', 1.0)
mscale = self.config.attention_scaling
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing getattr(..., 1.0) makes this function fragile. If self.config is a standard TransformerConfig instance (which does not have attention_scaling), this will raise an AttributeError. Additionally, this patched function continues to ignore the mscale argument passed to it, which is what necessitates the hacky config mutation in gemma4.py. It is recommended to use the passed mscale argument and fallback to the config value only if it is at its default value.

Suggested change
mscale = self.config.attention_scaling
mscale = getattr(self.config, 'attention_scaling', 1.0) if mscale == 1.0 else mscale

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removal of getattr(..., 1.0) reduces the robustness of this module. If self.config is a standard TransformerConfig that lacks the attention_scaling attribute, this change will cause an AttributeError. It is recommended to retain the getattr call with a default value to ensure compatibility with various configuration objects.

Suggested change
mscale = self.config.attention_scaling
mscale = getattr(self.config, 'attention_scaling', 1.0)

rot_dim = freqs.shape[-1]

# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
Expand Down
Loading