[PyTorch] Minor optimizations to reduce CPU overheads in modules (#1191)
* CPU perf optimization in linear autograd function

Avoid enable_grad context when possible in cast function. Cache distributed group properties.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
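
A minimal sketch of the second point, mirroring the change to get_distributed_world_size in transformer_engine/pytorch/distributed.py further down: functools.lru_cache memoizes the result per process group, so repeated lookups on the hot path become a dictionary hit instead of a torch.distributed call. Caching is safe here only because a group's world size and rank are fixed after creation.

from functools import lru_cache
from typing import Optional

import torch


@lru_cache
def get_distributed_world_size(group: Optional[torch.distributed.ProcessGroup] = None) -> int:
    """Return world size for the distributed group (memoized per group)."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size(group=group)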

* CPU perf optimization in prepare_forward function

Avoid torch.nn.Module's implementation of __setattr__ for plain attributes.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
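
A minimal sketch of this fast path, matching the __setattr__ that module/base.py adds to TransformerEngineBaseModule below (the class name ExampleModule and the attribute set here are illustrative). torch.nn.Module.__setattr__ checks whether every assigned value is a parameter, buffer, or submodule; for a fixed set of known plain attributes that bookkeeping can be skipped by writing into __dict__ directly.

from typing import Any, Set

import torch


class ExampleModule(torch.nn.Module):
    # Attributes that are always plain Python values, never params/buffers/submodules
    _fast_setattr_names: Set[str] = {"activation_dtype", "fp8"}

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ExampleModule._fast_setattr_names:
            # Plain attr: bypass torch.nn.Module.__setattr__ bookkeeping
            self.__dict__[name] = value
        else:
            super().__setattr__(name, value)


module = ExampleModule()
module.fp8 = True  # direct __dict__ write, no isinstance checks on the value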

* Avoid module import in TE module forwards

Signed-off-by: Tim Moon <tmoon@nvidia.com>
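
The pattern behind this bullet, sketched with the is_cpu_offload_enabled helper that the diff adds to cpu_offload.py: the forwards used to run `from ..cpu_offload import CPUOffloadEnabled` on every call, because importing the bare boolean once at module load would freeze its value, while re-importing inside forward() pays the import-machinery cost each time. A small accessor can be imported once and still return the live value. The standalone layout below is illustrative; in the real code the helper lives in transformer_engine/pytorch/cpu_offload.py.

# cpu_offload-style module: mutable flag plus a cheap accessor
CPUOffloadEnabled = False


def is_cpu_offload_enabled() -> bool:
    """Check if CPU offloading is currently enabled."""
    return CPUOffloadEnabled


# Callers import the accessor once at module load time, e.g.
#   from .cpu_offload import is_cpu_offload_enabled
# and call is_cpu_offload_enabled() inside forward() to read the current flag.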

* Use fast getter for params

Signed-off-by: Tim Moon <tmoon@nvidia.com>
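
A sketch of the fast getter bound as _fast_get_param in module/base.py below (the Toy class is illustrative). torch.nn.Module keeps parameters in its _parameters dict rather than in the instance __dict__, so getattr(module, "weight0") first fails inside __getattribute__ and then falls back to the custom __getattr__; when the name is known to be a parameter, a bound _parameters.get is a plain dict lookup.

import torch


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight0 = torch.nn.Parameter(torch.zeros(4, 4))
        # Bound dict getter: direct lookup in self._parameters
        self._fast_get_param = self.__dict__["_parameters"].get


toy = Toy()
# Same object, without the __getattr__ fallback path
assert toy._fast_get_param("weight0") is toy.weight0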

* Reuse tensor dims in linear autograd func

Signed-off-by: Tim Moon <tmoon@nvidia.com>
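
A minimal sketch of the dims reuse (the fp8 and autograd plumbing of the real _Linear/_LayerNormLinear functions is omitted, and linear_forward is an illustrative name): the weight and input shapes are read into Python locals once and reused for the final reshape, instead of querying tensor attributes repeatedly.

import torch


def linear_forward(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Read tensor dims once and reuse the cached Python values
    out_features, in_features = weight.shape
    inp_shape = inp.shape
    assert inp_shape[-1] == in_features, "GEMM not possible"

    out = inp.view((-1, in_features)) @ weight.t()
    # Reshape with the cached dims rather than re-reading inp.shape
    return out.view(-1, *inp_shape[1:-1], out_features)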

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply optimizations to grouped linear

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid deepcopy in tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move _fast_setattr logic to __setattr__ method

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Oct 4, 2024
1 parent 10cceae commit 9d976bc
Showing 13 changed files with 140 additions and 70 deletions.
10 changes: 8 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

block = TransformerLayer(
block_args = (
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
@@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layernorm=False,
device="cuda",
)
graphed_block = copy.deepcopy(block)
block = TransformerLayer(*block_args, **block_kwargs)
graphed_block = TransformerLayer(*block_args, **block_kwargs)
with torch.no_grad():
for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
param2.copy_(param1)

out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/cpu_offload.py
@@ -16,6 +16,11 @@
CPUOffloadEnabled = False


def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled


class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/distributed.py
@@ -6,6 +6,7 @@
from __future__ import annotations

from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings

@@ -125,13 +126,15 @@ def set_tensor_model_parallel_attributes(
setattr(tensor, "partition_stride", stride)


@lru_cache
def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
"""Return world size for the distributed group."""
if not torch.distributed.is_initialized():
return 1
return torch.distributed.get_world_size(group=group)


@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
"""Return my rank for the distributed group."""
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/jit.py
@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)

@@ -119,7 +119,7 @@ def bgrad_dgelu_fused(
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp)

41 changes: 35 additions & 6 deletions transformer_engine/pytorch/module/base.py
@@ -11,7 +11,7 @@
import fcntl
import struct
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager

import torch
@@ -406,6 +406,36 @@ def __init__(self) -> None:
self.fsdp_wrapped = False
self.fsdp_group = None
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
self.activation_dtype: Optional[torch.dtype] = None

# Fast getter for parameters
# Note: torch.nn.Module does not store parameters like normal
# attrs, but rather in a dict. When attempting to access, the
# module will raise an AttributeError in __getattribute__ and
# call a custom __getattr__. This is unnecessary overhead if
# we know we are accessing a parameter.
self._fast_get_param: Callable[[str], torch.nn.Parameter]
self._fast_get_param = self.__dict__["_parameters"].get

# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
"activation_dtype",
"fp8",
"fp8_initialized",
"fp8_calibration",
"fp8_parameters",
}

def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)

def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
@@ -593,7 +623,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
return

# All checks after this have already been performed once, thus skip
if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
if self.activation_dtype == inp.dtype:
return

dtype = inp.dtype
@@ -708,10 +738,9 @@ def prepare_forward(
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
if not allow_non_contiguous:
yield inp.contiguous()
else:
yield inp
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp

if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
11 changes: 5 additions & 6 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -39,8 +39,9 @@
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..tensor import Float8Tensor, QuantizedTensor
from ..export import is_in_onnx_export_mode
from ..cpu_offload import is_cpu_offload_enabled

__all__ = ["GroupedLinear"]

@@ -715,11 +716,11 @@ def forward(

with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:

weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
weight_tensors = [
w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
]

weight_tensors_fp8 = [None] * self.num_gemms
@@ -746,8 +747,6 @@ def forward(
skip_update_flag=skip_fp8_weight_update,
)

from ..cpu_offload import CPUOffloadEnabled

if torch.is_grad_enabled():
linear_fn = _GroupedLinear.apply
args = []
@@ -763,7 +762,7 @@ def forward(
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
self._offsets,
17 changes: 15 additions & 2 deletions transformer_engine/pytorch/module/layernorm.py
@@ -12,7 +12,6 @@
from torch.nn import init

import transformer_engine_torch as tex
from .base import TransformerEngineBaseModule
from ..cpp_extensions import (
layernorm_fwd_inf,
)
@@ -143,6 +142,7 @@ def __init__(
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None

self.reset_parameters(defer_init=(device == "meta"))

@@ -186,8 +186,21 @@ def reset_parameters(self, defer_init=False) -> None:
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""

# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype

if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
26 changes: 12 additions & 14 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -46,6 +46,7 @@
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled

__all__ = ["LayerNormLinear"]

@@ -94,8 +95,9 @@ def forward(
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat)
@@ -339,7 +341,7 @@ def forward(
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
@@ -369,7 +371,7 @@ def forward(
out, _ = allreduce(out, tp_group)

# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
out = out.view(-1, *inp_shape[1:-1], out_features)

if return_layernorm_output:
if return_layernorm_output_gathered:
@@ -1149,7 +1151,7 @@ def forward(
with self.prepare_forward(inp, is_first_microbatch) as inp:

# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
@@ -1160,11 +1162,9 @@ def forward(
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused

# Initialize FP8 weights if needed
weight_fp8 = None
@@ -1190,8 +1190,6 @@ def forward(
skip_update_flag=skip_fp8_weight_update,
)

from ..cpu_offload import CPUOffloadEnabled

if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
@@ -1200,8 +1198,8 @@ def forward(
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self._fast_get_param("layer_norm_weight"),
self._fast_get_param("layer_norm_bias"),
weight_tensor,
weight_fp8,
bias_tensor,
@@ -1212,7 +1210,7 @@ def forward(
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,