Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 179 additions & 3 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,165 @@ def fake_int4_quantization_ste(x, group_size):

return x_out

# ------------------------------------------------------------------
# MXFP4 fake-QAT. Matches the FlashInfer MXFP4 rollout kernel:
# per-block (1 x 32) absmax -> E8M0 power-of-2 scale -> round to E2M1 grid
# -> dequant multiply. Straight-through backward. Gated by
# OPEN_TRAINING_MXFP4_FAKE_QAT_FLAG; block size defaults to 32 and is
# overridable via OPEN_TRAINING_MXFP4_BLOCK_SIZE for ablations. Mutually
# exclusive with OPEN_TRAINING_INT4_FAKE_QAT_FLAG.
#
# Bit-exact equivalence to the rollout-path quantization has been verified
# on real DSV4-Flash expert weights against flashinfer.fp4_quantize +
# flashinfer.mxfp4_dequantize_host (see
# tools/infra/mxfp4_ste_vs_flashinfer_probe_ckpt.py).
# ------------------------------------------------------------------
# Non-negative magnitudes representable in E2M1, in ascending order.
_MXFP4_E2M1_POS_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
# Largest representable E2M1 magnitude (last entry of the grid above).
_MXFP4_E2M1_MAX = 6.0
# Default number of consecutive elements that share one scale.
_MXFP4_BLOCK_SIZE = 32

# Midpoints between adjacent entries of _MXFP4_E2M1_POS_GRID; these are the
# decision thresholds used when rounding a magnitude onto the grid.
_MXFP4_E2M1_BOUNDARIES = (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0)

# (device, dtype) -> (boundaries, grid). The tables are tiny, but building
# them on every call would allocate temporaries once per quantized weight.
# Keying on both device and dtype means callers that alternate between
# dtypes or devices do not evict each other's tables.
_MXFP4_E2M1_TABLE_CACHE = {}


def _get_e2m1_tables(device, dtype):
    """Return cached ``(boundaries, grid)`` tensors for E2M1 rounding.

    Args:
        device: Device the tables must live on (anything ``torch.device``
            accepts).
        dtype: Floating-point dtype of the tables.

    Returns:
        Tuple ``(boundaries, grid)``: a 7-element tensor of rounding
        thresholds and an 8-element tensor of E2M1 magnitudes, both on
        ``device`` with ``dtype``. Repeated calls with the same arguments
        return the same tensor objects, so callers must not modify them
        in place.
    """
    key = (torch.device(device), dtype)
    tables = _MXFP4_E2M1_TABLE_CACHE.get(key)
    if tables is None:
        tables = (
            torch.tensor(_MXFP4_E2M1_BOUNDARIES, dtype=dtype, device=device),
            torch.tensor(_MXFP4_E2M1_POS_GRID, dtype=dtype, device=device),
        )
        _MXFP4_E2M1_TABLE_CACHE[key] = tables
    return tables

def _round_to_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Snap every element of ``x`` to the nearest signed E2M1 value.

    The magnitude of each element is bucketed against the midpoints between
    adjacent grid values, the bucket index selects the grid magnitude, and
    the sign of ``x`` is re-applied.

    Tie handling: the lookup uses ``torch.bucketize(..., right=False)``,
    which places a value that is exactly equal to a boundary in the lower
    bucket. A magnitude sitting exactly on a midpoint (e.g. 0.25 or 2.5)
    therefore rounds to the smaller grid magnitude, i.e. toward zero.
    NOTE(review): E2M1 converters commonly use round-to-nearest-even on
    ties; confirm that ties-toward-zero is what the reference kernel does.

    Args:
        x: Floating-point tensor. Magnitudes above the last boundary map to
            the largest grid value; callers are expected to have clamped
            ``x`` to the E2M1 range beforehand.

    Returns:
        A new tensor with the shape and dtype of ``x`` whose values all lie
        on the signed E2M1 grid.
    """
    edges, levels = _get_e2m1_tables(x.device, x.dtype)
    # Ask for int32 bucket indices rather than bucketize's default int64.
    bucket = torch.bucketize(x.abs(), edges, right=False, out_int32=True)
    # Keep the index inside the grid (highest valid index is len - 1).
    bucket.clamp_(max=len(_MXFP4_E2M1_POS_GRID) - 1)
    magnitudes = levels[bucket]  # same dtype as x
    # Drop the index tensor before allocating the signed result.
    del bucket
    return torch.copysign(magnitudes, x)

class _FakeMXFP4QuantizationSTE(torch.autograd.Function):
    """Fake MXFP4 quantization with a straight-through backward.

    Forward quantizes a 2-D tensor ``[M, N]`` along ``N``: every run of
    ``block_size`` consecutive elements in a row shares one power-of-two
    scale, the scaled values are rounded onto the E2M1 grid, and the result
    is multiplied back by the scale (quantize -> dequantize). Backward
    passes the incoming gradient through unchanged.

    Nothing is saved on ``ctx`` and the whole forward runs under
    ``torch.no_grad()``, so no intermediate is kept alive for backward.

    Memory notes:
      - when ``N`` is already a multiple of ``block_size`` the input is
        reshaped in place of being padded and copied;
      - only the per-block absmax / exponent math runs in fp32 (that tensor
        has ``numel / block_size`` elements); the element-wise divide,
        clamp and rounding stay in the dtype of ``x``;
      - intermediates are released with ``del`` as soon as they are no
        longer needed.
    """

    @staticmethod
    def forward(ctx, x, block_size):
        """Return the fake-quantized (quantize -> dequantize) copy of ``x``.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            x: 2-D floating-point tensor ``[M, N]``.
            block_size: Positive number of elements along ``N`` that share
                one scale.

        Returns:
            A new tensor with the shape and dtype of ``x``.

        Raises:
            ValueError: If ``x`` is not 2-D or ``block_size`` is not
                positive.
        """
        if x.dim() != 2:
            raise ValueError(
                f"fake MXFP4 quantization expects a 2-D tensor, got {x.dim()}-D"
            )
        if block_size < 1:
            raise ValueError(
                f"block_size must be a positive integer, got {block_size!r}"
            )

        with torch.no_grad():
            m, n = x.shape
            # Smallest multiple of block_size that is >= n.
            n_padded = -(-n // block_size) * block_size
            needs_unpad = n_padded != n

            if needs_unpad:
                # Zero-pad the quantization axis; zeros never raise a
                # block's absmax, so they do not affect the scale.
                x_padded = x.new_zeros((m, n_padded))
                x_padded[:, :n] = x
                blocks = x_padded.view(m, n_padded // block_size, block_size)
                del x_padded
            else:
                # Fast path: no copy for contiguous input. reshape (rather
                # than view) also accepts non-contiguous input.
                blocks = x.reshape(m, n // block_size, block_size)

            # Per-block absmax, promoted to fp32 for the exponent math.
            exponent = blocks.abs().amax(dim=-1, keepdim=True).float()
            # exponent = ceil(log2(absmax / E2M1_MAX)), limited to
            # [-127, 127]. The 1e-8 floor keeps log2 finite for all-zero
            # blocks.
            exponent.div_(_MXFP4_E2M1_MAX).clamp_(min=1e-8).log2_().ceil_().clamp_(
                min=-127.0, max=127.0
            )
            scale = torch.pow(2.0, exponent).to(x.dtype)
            del exponent

            # Divide / clamp / round in the input dtype, not fp32.
            scaled = (blocks / scale).clamp_(-_MXFP4_E2M1_MAX, _MXFP4_E2M1_MAX)
            del blocks
            quantized = _round_to_e2m1(scaled)
            del scaled

            # Dequantize in place; mul_ returns the same tensor object.
            dequant = quantized.mul_(scale)
            del quantized, scale

            if needs_unpad:
                return dequant.reshape(m, n_padded)[:, :n].contiguous()
            return dequant.view(m, n)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator.

        Returns ``grad_output`` unchanged as the gradient for ``x`` and
        ``None`` for ``block_size``.
        """
        return grad_output, None

def fake_mxfp4_quantization_ste(x, block_size=_MXFP4_BLOCK_SIZE):
    """Apply fake MXFP4 quantization to ``x`` with a straight-through gradient.

    Args:
        x: Tensor to quantize; forwarded to ``_FakeMXFP4QuantizationSTE``.
        block_size: Number of elements that share one scale. Defaults to
            ``_MXFP4_BLOCK_SIZE``.

    Returns:
        The quantize -> dequantize result. If ``x`` has a ``main_grad``
        attribute, the same object is attached to the returned tensor so
        that code looking up ``main_grad`` on the result still finds it.
    """
    quantized = _FakeMXFP4QuantizationSTE.apply(x, block_size)

    try:
        grad_buffer = x.main_grad
    except AttributeError:
        # Nothing to carry over.
        return quantized

    # Presumably the gradient accumulator that is filled outside this
    # function; it must stay reachable from the tensor handed to callers.
    quantized.main_grad = grad_buffer
    return quantized

class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
def _get_weight_tensors(self):
    """Get the weight tensors of the module.

    Starts from the parent class's weight tensors and, depending on
    environment variables, replaces each one with a fake-quantized version:

    - ``OPEN_TRAINING_INT4_FAKE_QAT_FLAG=1``: INT4 fake quantization with
      group size ``OPEN_TRAINING_INT4_GROUP_SIZE`` (default 128).
    - ``OPEN_TRAINING_MXFP4_FAKE_QAT_FLAG=1``: MXFP4 fake quantization with
      block size ``OPEN_TRAINING_MXFP4_BLOCK_SIZE`` (default 32).

    With neither flag set the weights are returned unchanged.

    Raises:
        RuntimeError: If both flags are set to ``1``.
    """
    weight_tensors = super()._get_weight_tensors()

    int4_enabled = os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1"
    mxfp4_enabled = os.getenv("OPEN_TRAINING_MXFP4_FAKE_QAT_FLAG", "0") == "1"

    if int4_enabled and mxfp4_enabled:
        raise RuntimeError(
            "INT4 and MXFP4 fake QAT are mutually exclusive; set exactly "
            "one of OPEN_TRAINING_INT4_FAKE_QAT_FLAG / "
            "OPEN_TRAINING_MXFP4_FAKE_QAT_FLAG to 1."
        )

    if int4_enabled:
        group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128"))

        weight_tensors = [
            fake_int4_quantization_ste(w, group_size)
            for w in weight_tensors
        ]

    elif mxfp4_enabled:
        # 32 is the default block size; the env variable allows overriding
        # it for ablations.
        block_size = int(os.getenv("OPEN_TRAINING_MXFP4_BLOCK_SIZE", "32"))
        weight_tensors = [
            fake_mxfp4_quantization_ste(w, block_size)
            for w in weight_tensors
        ]

    return weight_tensors

def _encode_extra_state(self, state):
Expand Down