Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@
multi_tensor_scale_torch,
multi_tensor_l2norm_torch,
multi_tensor_adam_torch,
multi_tensor_adam_fp8_torch,
multi_tensor_adam_capturable_torch,
multi_tensor_adam_capturable_master_torch,
multi_tensor_adam_param_remainder_torch,
multi_tensor_sgd_torch,
multi_tensor_compute_scale_and_scale_inv_torch,
multi_tensor_compute_scale_inv_e8m0_torch,
)

__all__ = [
Expand Down Expand Up @@ -105,7 +109,11 @@
"multi_tensor_scale_torch",
"multi_tensor_l2norm_torch",
"multi_tensor_adam_torch",
"multi_tensor_adam_fp8_torch",
"multi_tensor_adam_capturable_torch",
"multi_tensor_adam_capturable_master_torch",
"multi_tensor_adam_param_remainder_torch",
"multi_tensor_sgd_torch",
"multi_tensor_compute_scale_and_scale_inv_torch",
"multi_tensor_compute_scale_inv_e8m0_torch",
]
181 changes: 181 additions & 0 deletions transformer_engine/plugin/core/backends/reference/impl/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
"multi_tensor_scale_torch",
"multi_tensor_l2norm_torch",
"multi_tensor_adam_torch",
"multi_tensor_adam_fp8_torch",
"multi_tensor_adam_capturable_torch",
"multi_tensor_adam_capturable_master_torch",
"multi_tensor_adam_param_remainder_torch",
"multi_tensor_sgd_torch",
"multi_tensor_compute_scale_and_scale_inv_torch",
"multi_tensor_compute_scale_inv_e8m0_torch",
]


Expand Down Expand Up @@ -392,3 +396,180 @@ def multi_tensor_compute_scale_and_scale_inv_torch(
# Update scale and scale_inv
scale.copy_(computed_scale)
scale_inv.copy_(1.0 / computed_scale)


def multi_tensor_compute_scale_inv_e8m0_torch(
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    block_len: int,
) -> None:
    """
    Compute scale_inv in E8M0 format from amax values for MXFP8 quantization.

    For every ``(amax, scale_inv)`` pair, the biased exponent
    ``floor(log2(amax)) + 127`` is computed element-wise and copied into
    ``scale_inv`` in place. The exponent is extracted exactly with
    ``torch.frexp`` (no floating-point ``log2`` rounding), the computation is
    done in float32 regardless of the amax dtype, and the result is kept inside
    ``[0, 254]`` so the cast to ``uint8`` is always well defined.

    Args:
        chunk_size: Chunk size (unused in PyTorch implementation)
        noop_flag: If non-zero, skip computation. ``None`` means "never skip".
        tensor_lists: [amaxes, scale_invs]
        block_len: Block length for block-wise scaling (not read by this
            implementation)

    Raises:
        ValueError: If ``tensor_lists`` does not hold exactly two lists, or the
            two lists differ in length.

    Note:
        amax values at or below ``2**-127`` (including zero and negatives) map
        to biased exponent 0; ``+inf`` maps to 254. NaN inputs receive no
        special treatment.
    """
    if noop_flag is not None and noop_flag.item() != 0:
        return

    if len(tensor_lists) != 2:
        raise ValueError("tensor_lists should contain [amaxes, scale_invs]")

    amaxes, scale_invs = tensor_lists

    if len(amaxes) != len(scale_invs):
        raise ValueError("All tensor lists must have the same length")

    # Upper clamp bound: keeps +inf out of frexp, whose exponent is unspecified for it.
    float32_max = torch.finfo(torch.float32).max

    for amax, scale_inv in zip(amaxes, scale_invs):
        # Work in float32: in fp16 the lower bound 2**-127 would underflow to 0.
        amax_val = torch.clamp(amax.to(torch.float32), min=2**-127, max=float32_max)
        # frexp gives amax = mantissa * 2**exponent with mantissa in [0.5, 1),
        # hence floor(log2(amax)) == exponent - 1, computed without rounding error.
        _, exponent = torch.frexp(amax_val)
        # E8M0: biased exponent = floor(log2(amax)) + 127 = exponent + 126.
        # 255 is the E8M0 NaN encoding, so the valid range stops at 254.
        biased_exp = torch.clamp(exponent + 126, min=0, max=254).to(torch.uint8)
        scale_inv.copy_(biased_exp)


def multi_tensor_adam_fp8_torch(
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    lr: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: int,
    mode: int,
    bias_correction: int,
    weight_decay: float,
    fp8_dtype,
) -> None:
    """
    Reference-backend entry point for the FP8 Adam step.

    The reference backend has no FP8 path. Any ``fp8_dtype`` other than ``None``
    is rejected; with ``fp8_dtype=None`` the remaining arguments are handed to
    ``multi_tensor_adam_torch`` unchanged.

    Raises:
        NotImplementedError: If ``fp8_dtype`` is not ``None``.
    """
    fp8_requested = fp8_dtype is not None
    if fp8_requested:
        message = (
            "FP8 adam is not supported in the reference backend. "
            "FP8 training requires GPU acceleration. "
            "Please use a CUDA-enabled build or disable FP8 optimization."
        )
        raise NotImplementedError(message)

    # No FP8 dtype requested: run the regular Adam implementation.
    adam_args = (
        chunk_size,
        noop_flag,
        tensor_lists,
        lr,
        beta1,
        beta2,
        epsilon,
        step,
        mode,
        bias_correction,
        weight_decay,
    )
    multi_tensor_adam_torch(*adam_args)


def multi_tensor_adam_capturable_torch(
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    lr: torch.Tensor,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: torch.Tensor,
    mode: int,
    bias_correction: int,
    weight_decay: float,
    inv_scale: torch.Tensor,
) -> None:
    """
    Reference-backend stand-in for the capturable Adam step.

    ``lr`` and ``step`` may be tensors or plain scalars. Tensors are converted
    to Python scalars with ``.item()`` and the call is forwarded to
    ``multi_tensor_adam_torch``; no CUDA graph capture takes place here.

    Raises:
        NotImplementedError: If ``lr`` or ``step`` is a tensor with
            ``requires_grad=True`` (``lr`` is checked first).
    """
    # NOTE(review): ``inv_scale`` is accepted but never read in this function --
    # confirm that callers do not rely on gradient unscaling here.
    for arg_name, arg in (("lr", lr), ("step", step)):
        if isinstance(arg, torch.Tensor) and arg.requires_grad:
            raise NotImplementedError(
                f"Capturable adam with tensor {arg_name} is not supported in the reference backend. "
                "CUDA graph capture requires GPU acceleration. "
                f"Please use a CUDA-enabled build or use scalar {arg_name}."
            )

    # The regular Adam implementation is called with scalar lr/step.
    lr_scalar = lr.item() if isinstance(lr, torch.Tensor) else lr
    step_scalar = step.item() if isinstance(step, torch.Tensor) else step

    multi_tensor_adam_torch(
        chunk_size,
        noop_flag,
        tensor_lists,
        lr_scalar,
        beta1,
        beta2,
        epsilon,
        step_scalar,
        mode,
        bias_correction,
        weight_decay,
    )


def multi_tensor_adam_capturable_master_torch(
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    lr: torch.Tensor,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: torch.Tensor,
    mode: int,
    bias_correction: int,
    weight_decay: float,
    inv_scale: torch.Tensor,
) -> None:
    """
    Reference-backend stand-in for the capturable Adam step with master weights.

    Tensor-valued ``lr`` and ``step`` are reduced to Python scalars via
    ``.item()`` and everything is forwarded to ``multi_tensor_adam_torch``.
    ``tensor_lists`` is passed through untouched; this function itself performs
    no CUDA graph capture and no master-weight handling.

    Raises:
        NotImplementedError: If ``lr`` or ``step`` is a tensor with
            ``requires_grad=True`` (``lr`` is checked first).
    """

    def _reject_if_requires_grad(value, label):
        # Tensors that track gradients are refused; anything else passes.
        if not isinstance(value, torch.Tensor):
            return
        if not value.requires_grad:
            return
        raise NotImplementedError(
            f"Capturable master adam with tensor {label} is not supported in the reference backend. "
            "CUDA graph capture requires GPU acceleration. "
            f"Please use a CUDA-enabled build or use scalar {label}."
        )

    def _to_scalar(value):
        # Tensors become Python numbers; plain scalars are returned as-is.
        if isinstance(value, torch.Tensor):
            return value.item()
        return value

    # NOTE(review): ``inv_scale`` is accepted but never read in this function --
    # confirm that callers do not rely on gradient unscaling here.
    _reject_if_requires_grad(lr, "lr")
    _reject_if_requires_grad(step, "step")

    scalar_lr = _to_scalar(lr)
    scalar_step = _to_scalar(step)

    # Hand over to the regular Adam implementation with scalar lr/step.
    multi_tensor_adam_torch(
        chunk_size,
        noop_flag,
        tensor_lists,
        scalar_lr,
        beta1,
        beta2,
        epsilon,
        scalar_step,
        mode,
        bias_correction,
        weight_decay,
    )
127 changes: 127 additions & 0 deletions transformer_engine/plugin/core/backends/reference/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@
multi_tensor_scale_torch,
multi_tensor_l2norm_torch,
multi_tensor_adam_torch,
multi_tensor_adam_fp8_torch,
multi_tensor_adam_capturable_torch,
multi_tensor_adam_capturable_master_torch,
multi_tensor_adam_param_remainder_torch,
multi_tensor_sgd_torch,
multi_tensor_compute_scale_and_scale_inv_torch,
multi_tensor_compute_scale_inv_e8m0_torch,
)


Expand Down Expand Up @@ -557,6 +562,96 @@ def multi_tensor_adam(
weight_decay,
)

def multi_tensor_adam_fp8(
    self,
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    lr: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: int,
    mode: int,
    bias_correction: int,
    weight_decay: float,
    fp8_dtype,
) -> None:
    """Forward every argument to ``multi_tensor_adam_fp8_torch`` and return its result."""
    return multi_tensor_adam_fp8_torch(
        chunk_size=chunk_size,
        noop_flag=noop_flag,
        tensor_lists=tensor_lists,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        epsilon=epsilon,
        step=step,
        mode=mode,
        bias_correction=bias_correction,
        weight_decay=weight_decay,
        fp8_dtype=fp8_dtype,
    )

def multi_tensor_adam_capturable(
    self,
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    lr: torch.Tensor,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: torch.Tensor,
    mode: int,
    bias_correction: int,
    weight_decay: float,
    inv_scale: torch.Tensor,
) -> None:
    """Forward every argument to ``multi_tensor_adam_capturable_torch`` and return its result."""
    return multi_tensor_adam_capturable_torch(
        chunk_size=chunk_size,
        noop_flag=noop_flag,
        tensor_lists=tensor_lists,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        epsilon=epsilon,
        step=step,
        mode=mode,
        bias_correction=bias_correction,
        weight_decay=weight_decay,
        inv_scale=inv_scale,
    )

def multi_tensor_adam_capturable_master(
    self,
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    lr: torch.Tensor,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: torch.Tensor,
    mode: int,
    bias_correction: int,
    weight_decay: float,
    inv_scale: torch.Tensor,
) -> None:
    """Forward every argument to ``multi_tensor_adam_capturable_master_torch``; return its result."""
    return multi_tensor_adam_capturable_master_torch(
        chunk_size=chunk_size,
        noop_flag=noop_flag,
        tensor_lists=tensor_lists,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        epsilon=epsilon,
        step=step,
        mode=mode,
        bias_correction=bias_correction,
        weight_decay=weight_decay,
        inv_scale=inv_scale,
    )

def multi_tensor_adam_param_remainder(
self,
chunk_size: int,
Expand Down Expand Up @@ -613,6 +708,38 @@ def multi_tensor_sgd(
scale,
)

def multi_tensor_compute_scale_and_scale_inv(
    self,
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    max_fp8: float,
    force_pow_2_scales: bool,
    epsilon: float,
) -> None:
    """Forward every argument, in order, to ``multi_tensor_compute_scale_and_scale_inv_torch``."""
    # Kept positional: the callee's parameter names are defined in another module.
    forwarded = (
        chunk_size,
        noop_flag,
        tensor_lists,
        max_fp8,
        force_pow_2_scales,
        epsilon,
    )
    return multi_tensor_compute_scale_and_scale_inv_torch(*forwarded)

def multi_tensor_compute_scale_inv_e8m0(
    self,
    chunk_size: int,
    noop_flag: torch.Tensor,
    tensor_lists: List[List[torch.Tensor]],
    block_len: int,
) -> None:
    """Forward every argument to ``multi_tensor_compute_scale_inv_e8m0_torch``; return its result."""
    return multi_tensor_compute_scale_inv_e8m0_torch(
        chunk_size=chunk_size,
        noop_flag=noop_flag,
        tensor_lists=tensor_lists,
        block_len=block_len,
    )

def get_flash_attention_class(self):
from .flash_attention import FlashAttentionTorch

Expand Down
Loading
Loading