Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion megatron/core/dist_checkpointing/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def load_common(self, checkpoint_dir: Union[str, Path]):
msc = MultiStorageClientFeature.import_package()
return msc.torch.load(load_path, map_location='cpu')
else:
return torch.load(load_path, map_location='cpu')
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
if MultiStorageClientFeature.is_enabled():
Expand Down
13 changes: 8 additions & 5 deletions megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,10 +503,12 @@ def __init__(
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
if sh_ten.key not in metadata.state_dict_metadata:
raise KeyError(
f"{sh_ten.key} from model not in state dict:"
f" {sorted(metadata.state_dict_metadata.keys())}"
)
# raise KeyError(
# f"{sh_ten.key} from model not in state dict:"
# f" {sorted(metadata.state_dict_metadata.keys())}"
# )
print(f"{sh_ten.key} from model not in state dict, will skip")
continue
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
expected_shape = sh_ten.global_shape
if loaded_shape != expected_shape:
Expand All @@ -530,7 +532,7 @@ def _temporarily_bypass_shape_validation(self):
tensor_metadata = self.metadata.state_dict_metadata
metadata_with_sizes = [
(tensor_metadata[key], tensor_metadata[key].size, sharded_tensor)
for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items()
for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata
]
try:
# Temporarily set sizes to expected shapes
Expand Down Expand Up @@ -802,6 +804,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
allow_partial_load=True,
),
)

Expand Down
70 changes: 70 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ def __init__(
self.te_quant_params: Optional[TEQuantizationParams] = None

for param in self.parameters():
setattr(param, "parallel_mode", parallel_mode)
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, "allreduce", not self.expert_parallel)
Expand Down Expand Up @@ -1455,6 +1456,61 @@ def sharded_state_dict(


if HAVE_TE and is_te_min_version("1.9.0.dev0"):
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y

class _FakeInt4QuantizationSTE(torch.autograd.Function):
@staticmethod
def forward(ctx, x, group_size):
m, n = x.shape
block_size_m, block_size_n = 1, group_size


m_padded = ceil_div(m, block_size_m) * block_size_m
n_padded = ceil_div(n, block_size_n) * block_size_n

x_padded = torch.zeros(
(m_padded, n_padded),
dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x

x_view = x_padded.view(
m_padded // block_size_m,
block_size_m,
n_padded // block_size_n,
block_size_n
)

x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True)
q_max = 7
x_scale = x_max / q_max

x_scale = x_scale.clamp(min=1e-5)

x_div = x_view / x_scale
x_round = torch.round(x_div)

x_q_clamped = x_round.clamp(-q_max, q_max)

x_dequant_view = x_q_clamped * x_scale

x_dequant_full = x_dequant_view.view_as(x_padded)
x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype)

return x_out

@staticmethod
def backward(ctx, grad_output):
return grad_output, None

def fake_int4_quantization_ste(x, group_size):
x_out = _FakeInt4QuantizationSTE.apply(x, group_size)

if hasattr(x, 'main_grad'):
x_out.main_grad = x.main_grad

return x_out

class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
Expand Down Expand Up @@ -1671,6 +1727,20 @@ def forward(self, x, m_splits):
return out
return out, None

def _get_weight_tensors(self):
"""Get the weight tensors of the module."""
weight_tensors = super()._get_weight_tensors()

if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1":
group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128"))

weight_tensors = [
fake_int4_quantization_ste(w, group_size)
for w in weight_tensors
]

return weight_tensors

def _encode_extra_state(self, state):
# TE 2.0 changed the format of extra_state to be a byte tensor
if is_te_min_version("2.0.0"):
Expand Down
83 changes: 51 additions & 32 deletions megatron/core/fusions/fused_mla_yarn_rope_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel(
SIN,
emb_dim: tl.constexpr,
k_dim: tl.constexpr,
k_dim_ceil: tl.constexpr,
v_dim: tl.constexpr,
head_num: tl.constexpr,
batch_size,
Expand Down Expand Up @@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel(
cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))
sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))

KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads
kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads
mask = kv_off < head_num * stride_kv_nheads
k_in_off = kv_off + tl.arange(0, k_dim)[None, :]
v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :]
k = tl.load(KV_ptr + k_in_off, mask=mask)
v = tl.load(KV_ptr + v_in_off, mask=mask)
KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads
ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
kj_range = tl.arange(0, k_dim_ceil)[None, :]
mask_k = (ki_range < head_num) & (kj_range < k_dim)
mask_v = ki_range < head_num
k_off = ki_range * stride_kv_nheads + kj_range
if v_dim > 0:
v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
v = tl.load(KV_ptr + v_off, mask=mask_v)
else:
v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty)
k = tl.load(KV_ptr + k_off, mask=mask_k)

K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads
V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads
K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads
V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads

k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :]
v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :]
tl.store(K_ptr + k_out_off, k, mask=mask)
tl.store(V_ptr + v_out_off, v, mask=mask)
k_out_off = ki_range * stride_k_nheads + kj_range
tl.store(K_ptr + k_out_off, k, mask=mask_k)
if v_dim > 0:
v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :]
tl.store(V_ptr + v_out_off, v, mask=mask_v)

EMB = K_POS_EMB + pid_m * stride_emb_seq
# x1 = t[..., 0::2], x2 = t[..., 1::2]
Expand All @@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel(
x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)
x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)

x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
mask_x = x_range < head_num
x_left_off = (
tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads
x_range * stride_k_nheads
+ k_dim
+ tl.arange(0, emb_dim // 2)[None, :]
)
x_right_off = x_left_off + emb_dim // 2
tl.store(K_ptr + x_left_off, x_left, mask=mask)
tl.store(K_ptr + x_right_off, x_right, mask=mask)
tl.store(K_ptr + x_left_off, x_left, mask=mask_x)
tl.store(K_ptr + x_right_off, x_right, mask=mask_x)


@triton.autotune(
Expand All @@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel(
SIN,
emb_dim: tl.constexpr,
k_dim: tl.constexpr,
k_dim_ceil: tl.constexpr,
v_dim: tl.constexpr,
head_num: tl.constexpr,
batch_size,
Expand Down Expand Up @@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel(
else:
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)

dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
mask = dkv_off < head_num * stride_dkv_nheads
dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :]
dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :]

dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads
dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads
dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :]
dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
dk = tl.load(dK_ptr + dk_in_off, mask=mask)
dv = tl.load(dV_ptr + dv_in_off, mask=mask)
tl.store(dKV_ptr + dk_out_off, dk, mask=mask)
tl.store(dKV_ptr + dv_out_off, dv, mask=mask)
dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads
ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
kj_range = tl.arange(0, k_dim_ceil)[None, :]
mask_k = (ki_range < head_num) & (kj_range < k_dim)
mask_v = ki_range < head_num
dk_out_off = ki_range * stride_dkv_nheads + kj_range

dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads
dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads
dk_in_off = ki_range * stride_dk_nheads + kj_range

dk = tl.load(dK_ptr + dk_in_off, mask=mask_k)
tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k)

if v_dim > 0:
dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
dv = tl.load(dV_ptr + dv_in_off, mask=mask_v)
tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v)

if pid_head == 0:
x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)):
dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads
x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim
dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads
x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads
mask = x_off < head_num * stride_dk_nheads
x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :]
x_right_off = x_left_off + emb_dim // 2
Expand Down Expand Up @@ -632,6 +647,7 @@ def forward(

o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim)
o_value = kv.new_empty(total_seqlen, nheads, v_dim)
k_dim_ceil = triton.next_power_of_2(k_dim)

grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
rotary_fwd_kv_kernel[grid](
Expand All @@ -643,6 +659,7 @@ def forward(
sin,
emb_dim,
k_dim,
k_dim_ceil,
v_dim,
nheads,
batch_size,
Expand Down Expand Up @@ -700,6 +717,7 @@ def backward(ctx, dk, dv):

d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim)
d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim)
k_dim_ceil = triton.next_power_of_2(ctx.k_dim)

grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
rotary_bwd_kv_kernel[grid](
Expand All @@ -711,6 +729,7 @@ def backward(ctx, dk, dv):
sin,
ctx.emb_dim,
ctx.k_dim,
k_dim_ceil,
ctx.v_dim,
nheads,
batch_size,
Expand Down
76 changes: 75 additions & 1 deletion megatron/core/models/common/language_module/language_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import logging
import os
from typing import Optional, Tuple
from typing import Any, Dict, Literal, Optional, Tuple

import torch
from torch import Tensor
Expand All @@ -14,6 +14,10 @@
except:
te_parallel_cross_entropy = None
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
try:
from cut_cross_entropy import linear_cross_entropy
except ImportError:
linear_cross_entropy = None
from megatron.core.pipeline_parallel.utils import (
is_pp_first_stage,
is_pp_last_stage,
Expand Down Expand Up @@ -125,6 +129,76 @@ def check_and_set_env_variable(
check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto)

def compute_output_layer_and_language_model_loss(
self,
hidden: Tensor,
labels: Optional[Tensor],
weight: Tensor = None,
sequence_parallel_enabled: bool = False,
column_parallel_linear: torch.nn.Module = None,
col_linear_kwargs: Dict[str, Any] = {},
reduction: Literal["none", "sum", "mean"] = "none",
ignore_index: int = -100,
) -> Tensor:
"""Computes the language model logits and loss (Cross entropy across vocabulary)

Args:
hidden (Tensor): The hidden states from the transformer model
labels (Optional[Tensor]): The labels of dimension [batch size, seq length]
weight (Tensor): The weight tensor of shape [vocab size, hidden size].
Required if using fused linear cross entropy.
column_parallel_linear (torch.nn.Module): The column parallel linear
layer to use for computing logits when not using fused linear cross entropy.
col_linear_kwargs (Dict[str, Any]): Additional kwargs for column parallel linear layer
reduction (Optional[str]): The reduction method. Defaults to "none", and can be
one of "none", "sum", "mean".
ignore_index (Optional[int]): The index to ignore in the loss calculation.
Defaults to -100.

Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length].
"""
if (
self.config.cross_entropy_loss_fusion
and self.config.cross_entropy_fusion_impl == 'linear'
):
assert (
weight is not None
), "weight cannot be None when using fused linear cross entropy."
assert (
labels is not None
), "labels cannot be None when using fused linear cross entropy."
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
loss = linear_cross_entropy(
hidden,
weight,
labels,
tp_group=self.pg_collection.tp,
sequence_parallel=sequence_parallel_enabled,
reduction=reduction,
ignore_index=ignore_index,
)

# [s b] => [b, s]
loss = loss.view_as(labels).transpose(0, 1).contiguous()
return loss
else:
assert (
column_parallel_linear is not None
), "column_parallel_linear cannot be None when not using fused linear cross entropy."
# output
output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()}
output_layer_buffers = dict(column_parallel_linear.named_buffers())
logits, _ = torch.func.functional_call(
column_parallel_linear,
{**output_layer_params, **output_layer_buffers},
(hidden,),
col_linear_kwargs,
)

return self.compute_language_model_loss(labels, logits)

def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)

Expand Down
Loading