diff --git a/transformer_engine/plugin/core/backends/reference/impl/__init__.py b/transformer_engine/plugin/core/backends/reference/impl/__init__.py index f467767d61..deee9905ff 100644 --- a/transformer_engine/plugin/core/backends/reference/impl/__init__.py +++ b/transformer_engine/plugin/core/backends/reference/impl/__init__.py @@ -54,9 +54,13 @@ multi_tensor_scale_torch, multi_tensor_l2norm_torch, multi_tensor_adam_torch, + multi_tensor_adam_fp8_torch, + multi_tensor_adam_capturable_torch, + multi_tensor_adam_capturable_master_torch, multi_tensor_adam_param_remainder_torch, multi_tensor_sgd_torch, multi_tensor_compute_scale_and_scale_inv_torch, + multi_tensor_compute_scale_inv_e8m0_torch, ) __all__ = [ @@ -105,7 +109,11 @@ "multi_tensor_scale_torch", "multi_tensor_l2norm_torch", "multi_tensor_adam_torch", + "multi_tensor_adam_fp8_torch", + "multi_tensor_adam_capturable_torch", + "multi_tensor_adam_capturable_master_torch", "multi_tensor_adam_param_remainder_torch", "multi_tensor_sgd_torch", "multi_tensor_compute_scale_and_scale_inv_torch", + "multi_tensor_compute_scale_inv_e8m0_torch", ] diff --git a/transformer_engine/plugin/core/backends/reference/impl/optimizer.py b/transformer_engine/plugin/core/backends/reference/impl/optimizer.py index 890ae9a563..15bb877979 100644 --- a/transformer_engine/plugin/core/backends/reference/impl/optimizer.py +++ b/transformer_engine/plugin/core/backends/reference/impl/optimizer.py @@ -9,9 +9,13 @@ "multi_tensor_scale_torch", "multi_tensor_l2norm_torch", "multi_tensor_adam_torch", + "multi_tensor_adam_fp8_torch", + "multi_tensor_adam_capturable_torch", + "multi_tensor_adam_capturable_master_torch", "multi_tensor_adam_param_remainder_torch", "multi_tensor_sgd_torch", "multi_tensor_compute_scale_and_scale_inv_torch", + "multi_tensor_compute_scale_inv_e8m0_torch", ] @@ -392,3 +396,180 @@ def multi_tensor_compute_scale_and_scale_inv_torch( # Update scale and scale_inv scale.copy_(computed_scale) scale_inv.copy_(1.0 / computed_scale) + + +def multi_tensor_compute_scale_inv_e8m0_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + block_len: int, +) -> None: + """ + Compute scale_inv in E8M0 format from amax values for MXFP8 quantization. + + Args: + chunk_size: Chunk size (unused in PyTorch implementation) + noop_flag: If non-zero, skip computation + tensor_lists: [amaxes, scale_invs] + block_len: Block length for block-wise scaling + """ + if noop_flag is not None and noop_flag.item() != 0: + return + + if len(tensor_lists) != 2: + raise ValueError("tensor_lists should contain [amaxes, scale_invs]") + + amaxes, scale_invs = tensor_lists + + if len(amaxes) != len(scale_invs): + raise ValueError("All tensor lists must have the same length") + + for amax, scale_inv in zip(amaxes, scale_invs): + amax_val = torch.clamp(amax, min=2**-127) + # E8M0: biased exponent = floor(log2(amax)) + 127 + log2_amax = torch.floor(torch.log2(amax_val)) + biased_exp = (log2_amax + 127).to(torch.uint8) + scale_inv.copy_(biased_exp) + + +def multi_tensor_adam_fp8_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype, +) -> None: + """ + FP8 adam optimizer - reference backend fallback. + + Note: This is a fallback implementation that uses FP32 computation instead of FP8. + FP8 training is a GPU-specific feature and not supported in the reference backend. + """ + if fp8_dtype is not None: + raise NotImplementedError( + "FP8 adam is not supported in the reference backend. " + "FP8 training requires GPU acceleration. " + "Please use a CUDA-enabled build or disable FP8 optimization." + ) + + # Fallback to regular adam with FP32 computation + multi_tensor_adam_torch( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + +def multi_tensor_adam_capturable_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, +) -> None: + """ + Capturable adam optimizer - reference backend fallback. + + Note: This is a fallback implementation that does not support CUDA graph capture. + CUDA graph capture is a GPU-specific feature and not supported in the reference backend. + """ + if isinstance(lr, torch.Tensor) and lr.requires_grad: + raise NotImplementedError( + "Capturable adam with tensor lr is not supported in the reference backend. " + "CUDA graph capture requires GPU acceleration. " + "Please use a CUDA-enabled build or use scalar lr." + ) + + if isinstance(step, torch.Tensor) and step.requires_grad: + raise NotImplementedError( + "Capturable adam with tensor step is not supported in the reference backend. " + "CUDA graph capture requires GPU acceleration. " + "Please use a CUDA-enabled build or use scalar step." + ) + + # Fallback to regular adam with scalar parameters + multi_tensor_adam_torch( + chunk_size, + noop_flag, + tensor_lists, + lr.item() if isinstance(lr, torch.Tensor) else lr, + beta1, + beta2, + epsilon, + step.item() if isinstance(step, torch.Tensor) else step, + mode, + bias_correction, + weight_decay, + ) + + +def multi_tensor_adam_capturable_master_torch( + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, +) -> None: + """ + Capturable master adam optimizer - reference backend fallback. + + Note: This is a fallback implementation that does not support CUDA graph capture + or master weight management. These are GPU-specific features. + """ + if isinstance(lr, torch.Tensor) and lr.requires_grad: + raise NotImplementedError( + "Capturable master adam with tensor lr is not supported in the reference backend. " + "CUDA graph capture requires GPU acceleration. " + "Please use a CUDA-enabled build or use scalar lr." + ) + + if isinstance(step, torch.Tensor) and step.requires_grad: + raise NotImplementedError( + "Capturable master adam with tensor step is not supported in the reference backend. " + "CUDA graph capture requires GPU acceleration. " + "Please use a CUDA-enabled build or use scalar step." + ) + + # Fallback to regular adam with scalar parameters + multi_tensor_adam_torch( + chunk_size, + noop_flag, + tensor_lists, + lr.item() if isinstance(lr, torch.Tensor) else lr, + beta1, + beta2, + epsilon, + step.item() if isinstance(step, torch.Tensor) else step, + mode, + bias_correction, + weight_decay, + ) diff --git a/transformer_engine/plugin/core/backends/reference/reference.py b/transformer_engine/plugin/core/backends/reference/reference.py index b6b45342f4..5c77701f4c 100644 --- a/transformer_engine/plugin/core/backends/reference/reference.py +++ b/transformer_engine/plugin/core/backends/reference/reference.py @@ -53,8 +53,13 @@ multi_tensor_scale_torch, multi_tensor_l2norm_torch, multi_tensor_adam_torch, + multi_tensor_adam_fp8_torch, + multi_tensor_adam_capturable_torch, + multi_tensor_adam_capturable_master_torch, multi_tensor_adam_param_remainder_torch, multi_tensor_sgd_torch, + multi_tensor_compute_scale_and_scale_inv_torch, + multi_tensor_compute_scale_inv_e8m0_torch, ) @@ -557,6 +562,96 @@ def multi_tensor_adam( weight_decay, ) + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype, + ) -> None: + return multi_tensor_adam_fp8_torch( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + return multi_tensor_adam_capturable_torch( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + return multi_tensor_adam_capturable_master_torch( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + def multi_tensor_adam_param_remainder( self, chunk_size: int, @@ -613,6 +708,38 @@ def multi_tensor_sgd( scale, ) + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + return multi_tensor_compute_scale_and_scale_inv_torch( + chunk_size, + noop_flag, + tensor_lists, + max_fp8, + force_pow_2_scales, + epsilon, + ) + + def multi_tensor_compute_scale_inv_e8m0( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + block_len: int, + ) -> None: + return multi_tensor_compute_scale_inv_e8m0_torch( + chunk_size, + noop_flag, + tensor_lists, + block_len, + ) + def get_flash_attention_class(self): from .flash_attention import FlashAttentionTorch diff --git a/transformer_engine/plugin/core/backends/reference/register_ops.py b/transformer_engine/plugin/core/backends/reference/register_ops.py index 9d66e24056..b5dba96d87 100644 --- a/transformer_engine/plugin/core/backends/reference/register_ops.py +++ b/transformer_engine/plugin/core/backends/reference/register_ops.py @@ -468,6 +468,30 @@ def register_builtins(registry) -> None: vendor=None, priority=50, ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor=None, + priority=50, + ), OpImpl( op_name="multi_tensor_sgd", impl_id="reference.torch", @@ -476,6 +500,22 @@ def register_builtins(registry) -> None: vendor=None, priority=50, ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor=None, + priority=50, + ), + OpImpl( + op_name="multi_tensor_compute_scale_inv_e8m0", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.multi_tensor_compute_scale_inv_e8m0, is_avail), + vendor=None, + priority=50, + ), # FlashAttention class getter OpImpl( op_name="get_flash_attention_class", diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 06ef799ee1..1dbf51b185 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -10,7 +10,12 @@ import torch -from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm +from transformer_engine_torch import CommOverlapType + +try: + from transformer_engine_torch import bulk_overlap_ag_with_external_gemm +except ImportError: + bulk_overlap_ag_with_external_gemm = None from transformer_engine import te_device_type