diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 64d6f3f6b6..260adfc6d3 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -23,9 +23,12 @@ jobs: args: ${{ env.args }} # This job only runs for pull request comments - if: | - contains( ',ptrendx,ksivaman,', format(',{0},', github.actor)) && + if: > github.event.comment.body == '/blossom-ci' + && ( + github.actor == 'ptrendx' + || github.actor == 'ksivaman' + ) steps: - name: Check if comment is issued by authorized person run: blossom-ci diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2770919947..896d8f927e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,7 @@ jobs: name: 'Core' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04 + image: nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu22.04 options: --user root steps: - name: 'Dependencies' @@ -35,9 +35,14 @@ jobs: name: 'PyTorch' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/pytorch:24.05-py3 + image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04 options: --user root steps: + - name: 'Dependencies' + run: | + apt-get update + apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 + pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -48,7 +53,8 @@ jobs: NVTE_FRAMEWORK: pytorch MAX_JOBS: 1 - name: 'Sanity check' - run: python tests/pytorch/test_sanity_import.py + if: false # Sanity import test requires Flash Attention + run: python3 tests/pytorch/test_sanity_import.py jax: name: 'JAX' runs-on: ubuntu-latest @@ -70,7 +76,7 @@ jobs: name: 'PaddlePaddle' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/paddlepaddle:24.05-py3 + image: nvcr.io/nvidia/paddlepaddle:24.07-py3 options: --user root steps: - name: 'Checkout' diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 7a6d269573..cd47fa9a54 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -15,9 +15,25 @@ jobs: args: ${{ env.args }} # This job only runs for pull request comments - if: | - contains( ',ptrendx,ksivaman,schetlur-nv,timmoon10,zlsh80826,mingxu1067,cyanguwa,nzmora-nvidia,galagam,nouiz,denera,sudhakarsingh27,Oleg-Goncharov,phu0ngng,nvcforster,', format(',{0},', github.actor)) && + if: > startsWith(github.event.comment.body, '/te-ci') + && ( + github.actor == 'ptrendx' + || github.actor == 'ksivaman' + || github.actor == 'schetlur-nv' + || github.actor == 'timmoon10' + || github.actor == 'zlsh80826' + || github.actor == 'mingxu1067' + || github.actor == 'cyanguwa' + || github.actor == 'nzmora-nvidia' + || github.actor == 'galagam' + || github.actor == 'nouiz' + || github.actor == 'denera' + || github.actor == 'sudhakarsingh27' + || github.actor == 'Oleg-Goncharov' + || github.actor == 'phu0ngng' + || github.actor == 'xrennvidia' + ) steps: - name: Check if comment is issued by authorized person run: blossom-ci diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 98ca4e1941..2533f5e5c1 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 +Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b diff --git a/README.rst b/README.rst index 085c91ca49..25ed8af1de 100644 --- a/README.rst +++ b/README.rst @@ -149,8 +149,8 @@ Installation Pre-requisites ^^^^^^^^^^^^^^^^^^^^ * Linux x86_64 -* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada -* NVIDIA Driver supporting CUDA 11.8 or later +* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada +* NVIDIA Driver supporting CUDA 12.0 or later * cuDNN 8.1 or later * For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later. @@ -182,7 +182,7 @@ From source Compiling with FlashAttention-2 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. +Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index e5df485eda..bfd7bf8471 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -11,9 +11,7 @@ import transformer_engine from tests.pytorch.fused_attn.test_fused_attn import ( ModelConfig, - _is_flash_attention_supported, - _is_fused_attention_supported, - _is_unfused_attention_supported, + _get_attention_backends, _run_dot_product_attention, ) @@ -29,8 +27,6 @@ workspace_opt = True # QKV memory layout qkv_layout = "bshd_bshd_bshd" -# sliding window attention -swa = False # padding between sequences for qkv_format=thd pad_between_seqs = False # training mode @@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -205,13 +197,15 @@ def main(): ) for model in model_configs.keys(): config = model_configs[model] - fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( + available_backends, fused_attn_backends = _get_attention_backends( config, - dtype, + qkv_dtype=dtype, qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, ) - fused_attn_supported = fused_attn_supported and not swa - flash_attn_supported = _is_flash_attention_supported(config) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + print( f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' f'{" and flash-attention" if flash_attn_supported else ""}...' diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index f8e233b273..81c871de46 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.9.0 +1.10.0 diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 631b2b3627..f71cef08ea 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -10,6 +10,7 @@ import sys import sysconfig import copy +import time from pathlib import Path from subprocess import CalledProcessError @@ -69,8 +70,8 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: configure_command.append(f"-Dpybind11_DIR={pybind11_dir}") # CMake build and install commands - build_command = [_cmake_bin, "--build", build_dir] - install_command = [_cmake_bin, "--install", build_dir] + build_command = [_cmake_bin, "--build", build_dir, "--verbose"] + install_command = [_cmake_bin, "--install", build_dir, "--verbose"] # Check whether parallel build is restricted max_jobs = get_max_jobs_for_parallel_build() @@ -81,6 +82,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: build_command.append(str(max_jobs)) # Run CMake commands + start_time = time.perf_counter() for command in [configure_command, build_command, install_command]: print(f"Running command {' '.join(command)}") try: @@ -88,6 +90,9 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: except (CalledProcessError, OSError) as e: raise RuntimeError(f"Error when running CMake: {e}") + total_time = time.perf_counter() - start_time + print(f"Time for build_ext: {total_time:.2f} seconds") + def get_build_ext(extension_cls: Type[setuptools.Extension]): class _CMakeBuildExtension(extension_cls): diff --git a/build_tools/jax.py b/build_tools/jax.py index 72a22f683e..f829230f50 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -2,7 +2,8 @@ # # See LICENSE for license information. -"""Paddle-paddle related extensions.""" +"""JAX related extensions.""" +import os from pathlib import Path import setuptools @@ -12,6 +13,25 @@ from typing import List +def xla_path() -> str: + """XLA root path lookup. + Throws FileNotFoundError if XLA source is not found.""" + + try: + from jax.extend import ffi + except ImportError: + if os.getenv("XLA_HOME"): + xla_home = Path(os.getenv("XLA_HOME")) + else: + xla_home = "/opt/xla" + else: + xla_home = ffi.include_dir() + + if not os.path.isdir(xla_home): + raise FileNotFoundError("Could not find xla source.") + return xla_home + + def setup_jax_extension( csrc_source_files, csrc_header_files, @@ -27,12 +47,14 @@ def setup_jax_extension( # Header files cuda_home, _ = cuda_path() + xla_home = xla_path() include_dirs = [ cuda_home / "include", common_header_files, common_header_files / "common", common_header_files / "common" / "include", csrc_header_files, + xla_home, ] # Compile flags diff --git a/build_tools/paddle.py b/build_tools/paddle.py index 163f094fce..f410682875 100644 --- a/build_tools/paddle.py +++ b/build_tools/paddle.py @@ -6,6 +6,7 @@ from pathlib import Path import setuptools +import os from .utils import cuda_version @@ -61,12 +62,18 @@ def setup_paddle_extension( except FileNotFoundError: print("Could not determine CUDA Toolkit version") else: - if version >= (11, 2): - nvcc_flags.extend(["--threads", "4"]) - if version >= (11, 0): - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if version >= (11, 8): - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + if version < (12, 0): + raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") + nvcc_flags.extend( + ( + "--threads", + os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), + "-gencode", + "arch=compute_80,code=sm_80", + "-gencode", + "arch=compute_90,code=sm_90", + ) + ) # Construct Paddle CUDA extension sources = [str(path) for path in sources] diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e423ffe907..f932f0695e 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -67,12 +67,18 @@ def setup_pytorch_extension( except FileNotFoundError: print("Could not determine CUDA Toolkit version") else: - if version >= (11, 2): - nvcc_flags.extend(["--threads", "4"]) - if version >= (11, 0): - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if version >= (11, 8): - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + if version < (12, 0): + raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") + nvcc_flags.extend( + ( + "--threads", + os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), + "-gencode", + "arch=compute_80,code=sm_80", + "-gencode", + "arch=compute_90,code=sm_90", + ) + ) # Libraries library_dirs = [] diff --git a/build_tools/utils.py b/build_tools/utils.py index 3230ad35bf..27ceea844b 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -14,7 +14,7 @@ import importlib from pathlib import Path from subprocess import CalledProcessError -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union @functools.lru_cache(maxsize=None) @@ -37,8 +37,8 @@ def get_max_jobs_for_parallel_build() -> int: num_jobs = 0 # Check environment variable - if os.getenv("NVTE_MAX_BUILD_JOBS"): - num_jobs = int(os.getenv("NVTE_MAX_BUILD_JOBS")) + if os.getenv("NVTE_BUILD_MAX_JOBS"): + num_jobs = int(os.getenv("NVTE_BUILD_MAX_JOBS")) elif os.getenv("MAX_JOBS"): num_jobs = int(os.getenv("MAX_JOBS")) @@ -254,12 +254,39 @@ def get_frameworks() -> List[str]: return _frameworks -def copy_common_headers(te_src, dst): - headers = te_src / "common" - for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True): - new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :]) - Path(new_path).parent.mkdir(exist_ok=True, parents=True) - shutil.copy(file_path, new_path) +def copy_common_headers( + src_dir: Union[Path, str], + dst_dir: Union[Path, str], +) -> None: + """Copy headers from core library + + src_dir should be the transformer_engine directory within the root + Transformer Engine repository. All .h and .cuh files within + transformer_engine/common are copied into dst_dir. Relative paths + are preserved. + + """ + + # Find common header files in src dir + headers = glob.glob( + os.path.join(str(src_dir), "common", "**", "*.h"), + recursive=True, + ) + headers.extend( + glob.glob( + os.path.join(str(src_dir), "common", "**", "*.cuh"), + recursive=True, + ) + ) + headers = [Path(path) for path in headers] + + # Copy common header files to dst dir + src_dir = Path(src_dir) + dst_dir = Path(dst_dir) + for path in headers: + new_path = dst_dir / path.relative_to(src_dir) + new_path.parent.mkdir(exist_ok=True, parents=True) + shutil.copy(path, new_path) def install_and_import(package): @@ -269,7 +296,7 @@ def install_and_import(package): globals()[main_package] = importlib.import_module(main_package) -def uninstall_te_fw_packages(): +def uninstall_te_wheel_packages(): subprocess.check_call( [ sys.executable, @@ -277,6 +304,7 @@ def uninstall_te_fw_packages(): "pip", "uninstall", "-y", + "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_paddle", "transformer_engine_jax", diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index a0bcd80347..7d839958cb 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 602d99ed4d..7dedf2a761 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 1896fc4e42..7682a2b6aa 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -5,10 +5,11 @@ set -e PLATFORM=${1:-manylinux_2_28_x86_64} -BUILD_COMMON=${2:-true} -BUILD_JAX=${3:-true} +BUILD_METAPACKAGE=${2:-true} +BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} -BUILD_PADDLE=${5:-true} +BUILD_JAX=${5:-true} +BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -20,12 +21,33 @@ cd /TransformerEngine git checkout $TARGET_BRANCH git submodule update --init --recursive +if $BUILD_METAPACKAGE ; then + cd /TransformerEngine + NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt + mv dist/* /wheelhouse/ +fi + if $BUILD_COMMON ; then + VERSION=`cat build_tools/VERSION.txt` + WHL_BASE="transformer_engine-${VERSION}" + + # Create the wheel. /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + + # Repack the wheel for cuda specific package, i.e. cu12. + /opt/python/cp38-cp38/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" + /opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE} + + # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" - mv dist/"$whl_name" /wheelhouse/"$whl_name_target" + whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" + rm -rf $WHL_BASE dist + mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then @@ -37,8 +59,8 @@ fi if $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax - /opt/python/cp38-cp38/bin/pip install jax jaxlib - /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib + /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi @@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 cd /TransformerEngine/transformer_engine/paddle - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl + /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl + /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl + /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl + /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl + /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu mv dist/* /wheelhouse/ fi diff --git a/docs/conf.py b/docs/conf.py index 77751994d8..7a50ce76cf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -47,7 +47,10 @@ git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha -version = str(te_version + "-" + git_sha) +if "dev" in te_version: + version = str(te_version + "-" + git_sha) +else: + version = str(te_version) release = te_version # hack: version is used for html creation, so put the version picker diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index cd8ab85ba2..85ce01079c 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -6,7 +6,6 @@ import torch from typing import Tuple from tests.pytorch.fused_attn.test_fused_attn import ModelConfig -from transformer_engine.pytorch.distributed import _set_cuda_rng_state from transformer_engine.pytorch.attention import DotProductAttention # Initialize RNG state @@ -22,7 +21,7 @@ def reset_rng_states() -> None: """Revert back to initial RNG state""" torch.set_rng_state(_cpu_rng_state) - _set_cuda_rng_state(_cuda_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) def _run_dot_product_attention( @@ -40,7 +39,7 @@ def _run_dot_product_attention( [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" ) inp = torch.randn( - [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim], + [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk], dtype=dtype, device="cuda", ) @@ -51,7 +50,7 @@ def _run_dot_product_attention( k.requires_grad = True v.requires_grad = True out_grad = torch.randn( - [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim], + [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v], dtype=dtype, device="cuda", ) @@ -80,7 +79,7 @@ def _run_dot_product_attention( block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, qkv_format="bshd", attention_dropout=config.dropout_p, @@ -89,6 +88,8 @@ def _run_dot_product_attention( get_rng_state_tracker=None, tp_group=None, layer_number=1, + attn_mask_type="no_mask", + window_size=(-1, -1), ).to(dtype=dtype, device="cuda") # Run a forward and backward pass @@ -103,6 +104,7 @@ def _run_dot_product_attention( attn_mask_type=config.attn_mask_type, # 'arbitrary' core_attention_bias_type=config.attn_bias_type, # 'no_bias' core_attention_bias=bias, # None + window_size=(-1, -1), ) out.backward(out_grad) @@ -116,6 +118,7 @@ def _run_dot_product_attention( attn_mask_type=config.attn_mask_type, # no_mask core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias' core_attention_bias=bias, # bias + window_size=(-1, -1), ) out.backward(out_grad) @@ -133,6 +136,7 @@ def _run_dot_product_attention( config = model_configs["test_bias"] fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") +print() print("Run with arbitrary mask:") config = model_configs["test_mask"] unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd") @@ -140,4 +144,6 @@ def _run_dot_product_attention( torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2) for i in range(3): torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2) + +print() print("Test passed!") diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 515f420790..27017b4773 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "8ae3bc43", + "id": "040f466a", "metadata": {}, "source": [ "# Attention Is All You Need!\n", @@ -23,7 +23,7 @@ }, { "cell_type": "markdown", - "id": "47421c01", + "id": "89a7d849", "metadata": {}, "source": [ "## 1. Attention Backends\n", @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "e52f60f0", + "id": "c90a2573", "metadata": {}, "source": [ "### 1.1 Flash vs. Non-Flash\n", @@ -85,30 +85,30 @@ "- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n", "\n", "
\n", - "Note \n", + "Note: \n", " \n", - "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", + "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", "
\n" ] }, { "cell_type": "markdown", - "id": "bb909ac4", + "id": "b5ce567d", "metadata": {}, "source": [ "### 1.2 flash-attention\n", "\n", "The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n", "\n", - "The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n", + "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n", "\n", - "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", + "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "\n", - "To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", + "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "\n", "### 1.3 cuDNN Attention\n", "\n", - "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n", + "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", "\n", "\n", " \n", @@ -153,14 +153,14 @@ " \n", "
\n", "\n", - "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n", + "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n", "\n", "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n", "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n", - "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).\n", + "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n", "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", - "- flash-attention supports sliding window attention, and cuDNN attention does not.\n", - "- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n", + "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n", + "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n", "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n", "\n", "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0." @@ -169,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a380859", + "id": "c5b8e3d7", "metadata": {}, "outputs": [], "source": [ @@ -184,25 +184,25 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "0584bb01", + "execution_count": 1, + "id": "50852cb5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory\n", + "Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory\n", "Running test_0 with cuDNN attention and flash-attention...\n", "Running test_1 with cuDNN attention and flash-attention...\n", "Running test_2 with cuDNN attention...\n", "Running test_3 with cuDNN attention and flash-attention...\n", "\n", " cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup\n", - "test_0 0.0638 0.0858 1.3454\n", - "test_1 0.5415 0.7496 1.3842\n", - "test_2 1.2302 0.0000 0.0000\n", - "test_3 12.0122 19.0716 1.5877\n" + "test_0 0.0340 0.0468 1.3786\n", + "test_1 0.3664 0.5850 1.5968\n", + "test_2 0.9332 0.0000 0.0000\n", + "test_3 7.4875 11.8879 1.5877\n" ] } ], @@ -212,7 +212,7 @@ }, { "cell_type": "markdown", - "id": "45e53fc9", + "id": "9a615119", "metadata": {}, "source": [ "## 2. Backend Selection\n", @@ -253,35 +253,35 @@ }, { "cell_type": "markdown", - "id": "6dfeade3", + "id": "e6c0f3f0", "metadata": {}, "source": [ "### 2.1 Debug Information\n", "\n", - "To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the `logging` package.\n", + "To find out which backend is being used during runtime, we have the following two debugging flags. Logging is done by using the `logging` package.\n", "```\n", "NVTE_DEBUG = 0/1 # disables/enables debugging\n", "NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n", "```\n", "
\n", - "Note\n", + "Note:\n", " \n", - "These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n", + "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n", "
" ] }, { "cell_type": "markdown", - "id": "7e3b7981", + "id": "16660323", "metadata": {}, "source": [ - "The [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime." + "The example script [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend is used in runtime." ] }, { "cell_type": "code", - "execution_count": 22, - "id": "961c51d4", + "execution_count": 24, + "id": "906b8cf1", "metadata": {}, "outputs": [ { @@ -293,7 +293,7 @@ "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", "\n", "Run flash-attention...\n", - "[INFO | DotProductAttention]: Running with FlashAttention backend \n", + "[INFO | DotProductAttention]: Running with FlashAttention backend\n", "\n", "Test passed.\n" ] @@ -305,16 +305,16 @@ }, { "cell_type": "markdown", - "id": "11bfbbd7", + "id": "8ca99461", "metadata": {}, "source": [ - "To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, " + "`NVTE_DEBUG_LEVEL=2` allows us to find out more about the backend selection logic. Users are encouraged to double check the `config` and provide it to the Transformer Engine team if they would like to file a bug. " ] }, { "cell_type": "code", - "execution_count": 25, - "id": "162a2be1", + "execution_count": 23, + "id": "d3637094", "metadata": {}, "outputs": [ { @@ -323,16 +323,18 @@ "text": [ "\n", "Run cuDNN attention...\n", + "[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': , 'cudnn_version': '9.3.0', 'qkv_type': , 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n", "[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n", + "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}\n", + "[DEBUG | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)\n", "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", - "[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': , 'flash_attn_version': , 'cudnn_version': '9.2.0'}\n", - "[DEBUG | FusedAttnFunc ]: Running forward in torch.bfloat16\n", - "[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n", "\n", "Run flash-attention...\n", + "[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': , 'cudnn_version': '9.3.0', 'qkv_type': , 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n", "[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n", - "[INFO | DotProductAttention]: Running with FlashAttention backend \n", - "[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': , 'flash_attn_version': , 'cudnn_version': '9.2.0'}\n", + "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}\n", + "[DEBUG | DotProductAttention]: Selected backend = FlashAttention\n", + "[INFO | DotProductAttention]: Running with FlashAttention backend\n", "\n", "Test passed.\n" ] @@ -344,7 +346,7 @@ }, { "cell_type": "markdown", - "id": "779a51e6", + "id": "611d8fdb", "metadata": {}, "source": [ "### 2.2 User Control\n", @@ -392,28 +394,29 @@ }, { "cell_type": "markdown", - "id": "ccd5650d", + "id": "e60a2a3e", "metadata": {}, "source": [ "## 3. Backend Support\n", "\n", - "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n", + "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n", "\n", - "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |\n", - "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n", - "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for `bshd`,`sbhd`) | Yes |\n", - "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for `bshd`,`thd`) | Yes |\n", - "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n", + "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", + "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", + "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n", + "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n", "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", + "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" ] }, { "cell_type": "markdown", - "id": "8439b389", + "id": "fbdcb327", "metadata": {}, "source": [ "### 3.1 QKV Layout\n", @@ -439,7 +442,7 @@ "**qkv_layout=thd_thd_thd:**\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "\n", - "As of v1.7, Transformer Engine has the following support matrix.\n", + "As of v1.10, Transformer Engine has the following support matrix.\n", "\n", "\n", " \n", @@ -480,16 +483,16 @@ }, { "cell_type": "markdown", - "id": "0290f8e9", + "id": "855d9616", "metadata": {}, "source": [ "### 3.2 Attention Mask\n", "\n", - "Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n", + "Transformer Engine supports 7 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n", "\n", - "- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n", + "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n", "\n", - "Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n", + "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n", "\n", "
\n", " \n", @@ -498,34 +501,25 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", "
Requires `attention_mask`
flash-attention`no_mask`, `causal`, `padding`, `padding_causal``no_mask`, `causal`: No
`padding`, `padding_causal`: Yes if `cu_seqlens` not provided
cuDNN attention`no_mask`, `causal`, `padding`, `padding_causal``no_mask`, `causal`: Noflash-attention
  • `no_mask`, `causal` (self-attention),
  • `padding`, `padding_causal` (self-attention),
  • `causal_bottom_right`, `padding_causal_bottom_right`
  • `no_mask`, `causal` `causal_bottom_right`: No
  • `padding`, `padding_causal`, `padding_causal_bottom_right`: Yes if `cu_seqlens` not provided
  • `arbitrary`: Yes
  • \n", - " `padding`, `padding_causal`: Yes if `cu_seqlens` not provided\n", - " cuDNN attention
  • `no_mask`, `causal`,
  • `padding`, `padding_causal`,
  • `causal_bottom_right`, `padding_causal_bottom_right`
  • Framework-native attention`no_mask`, `causal`, `arbitrary``no_mask`, `causal`: NoFramework-native attention
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax, PaddlePaddle)
  • `arbitrary`: Yes
    \n", "\n", - "**padding and padding_causal:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", + "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", @@ -536,13 +530,13 @@ "\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", - "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n" + "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.3. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "b1b7cdd4", + "execution_count": 33, + "id": "a1f25a9b", "metadata": {}, "outputs": [ { @@ -550,27 +544,29 @@ "output_type": "stream", "text": [ "Run with post_scale_bias:\n", - "[DotProductAttention]: using cuDNN attention (sub-backend 1)\n", + "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n", + "\n", "Run with arbitrary mask:\n", - "[DotProductAttention]: using unfused DPA\n", + "[INFO | DotProductAttention]: Running with UnfusedDotProductAttention backend\n", + "\n", "Test passed!\n" ] } ], "source": [ - "!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py" + "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py" ] }, { "cell_type": "markdown", - "id": "e045c284", + "id": "dda4a589", "metadata": {}, "source": [ "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", - "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n", + "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n", "\n", "\n", " \n", @@ -617,25 +613,20 @@ }, { "cell_type": "markdown", - "id": "8b8a4e40", + "id": "a0702339", "metadata": {}, "source": [ "### 3.4 FP8 Attention\n", "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.7. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n", - "```\n", - "[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n", - "[DEBUG | FusedAttnFunc ]: Running forward in FP8\n", - "[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n", - "```" + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." ] } ], diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 2ed7303417..15022005a1 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -11,9 +11,7 @@ import transformer_engine from tests.pytorch.fused_attn.test_fused_attn import ( ModelConfig, - _is_flash_attention_supported, - _is_fused_attention_supported, - _is_unfused_attention_supported, + _get_attention_backends, _run_dot_product_attention, ) @@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported): ckpt_attn, qkv_layout, workspace_opt, - swa, pad_between_seqs, is_training, ) @@ -94,13 +90,14 @@ def main(): models = ["test_0"] for model in models: config = model_configs[model] - fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( + available_backends, fused_attn_backends = _get_attention_backends( config, - dtype, + qkv_dtype=dtype, qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, ) - fused_attn_supported = fused_attn_supported and not swa - flash_attn_supported = _is_flash_attention_supported(config) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends example_attention(model, fused_attn_supported, flash_attn_supported) diff --git a/docs/installation.rst b/docs/installation.rst index 5dd10a79d1..012f3303cb 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -12,8 +12,8 @@ Prerequisites .. _driver link: https://www.nvidia.com/drivers 1. Linux x86_64 -2. `CUDA 11.8 `__ -3. |driver link|_ supporting CUDA 11.8 or later. +2. `CUDA 12.0 `__ +3. |driver link|_ supporting CUDA 12.0 or later. 4. `cuDNN 8.1 `__ or later. 5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 `__ or later. diff --git a/examples/jax/encoder/requirements.txt b/examples/jax/encoder/requirements.txt index 40b1915c96..26af82aa49 100644 --- a/examples/jax/encoder/requirements.txt +++ b/examples/jax/encoder/requirements.txt @@ -1,4 +1,4 @@ datasets flax>=0.7.1 -nltk +nltk>=3.8.2 optax diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 716d543d5b..646d6e0a12 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -168,7 +168,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index c6223ed5bb..005ae50e72 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -147,7 +147,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c9620aa2be..286c064e96 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -250,7 +250,7 @@ def eval_model( def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 674f7de815..363759afea 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -144,7 +144,7 @@ def eval_model(state, test_ds, batch_size, var_collect): def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt") + nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) diff --git a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py similarity index 50% rename from examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py rename to examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index 412c948a83..ab6b656be9 100644 --- a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -7,6 +7,8 @@ import os import sys import socket +import fcntl +import struct import argparse import warnings @@ -15,15 +17,37 @@ from torch.nn.parallel import DistributedDataParallel import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.common.recipe import Format, DelayedScaling +warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser( - description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." ) parser.add_argument( "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations." @@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None): "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." ) parser.add_argument( - "--mlp-expansion-factor", - type=int, - default=4, - help="MLP block intermediate size as a factor of hidden dimension.", + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( @@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None): help="Print out additional debug information.", ) args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True return args +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap + kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(4 * hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + + def _train(opts): if "OMPI_COMM_WORLD_SIZE" in os.environ: # Execution with `mpirun -np N` @@ -110,19 +182,6 @@ def _train(opts): raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") NUM_NODES = WORLD_SIZE // LOCAL_SIZE - def dist_print(msg, group=None, end="\n", debug=False): - if debug and not opts.debug: - return - group = dist.new_group() if group is None else group - group_rank = dist.get_rank(group) - group_size = dist.get_world_size(group) - all_ranks = dist.get_process_group_ranks(group) - ranks_skip = all_ranks[1] - all_ranks[0] > 1 - group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size - if group_rank == 0 or opts.verbose: - print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True) - dist.barrier(group) - # Initialize torch.distributed global process group and get DP/TP groups torch.cuda.set_device(LOCAL_RANK) dist_init_kwargs = { @@ -143,75 +202,117 @@ def dist_print(msg, group=None, end="\n", debug=False): assert dist.is_nccl_available() dist.init_process_group(**dist_init_kwargs) nccl_world = dist.new_group(backend="nccl") - dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world) + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Figure out process groups for tensor- and data-parallelism (if any) if NUM_NODES > 1: # Create a list of world ranks on this node - hostnames = [None for _ in range(WORLD_SIZE)] hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] dist.all_gather_object(hostnames, hostname) - node_ranks = [] + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) if host == hostname: - node_ranks.append(i) + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] if opts.num_replicas > 1: # Split node ranks into multiple replicas - assert len(node_ranks) % opts.num_replicas == 0 - tp_size = len(node_ranks) // opts.num_replicas - found_replica = False - for replica in range(opts.num_replicas): - start = replica * tp_size - end = start + tp_size - tp_ranks = node_ranks[start:end] - if WORLD_RANK in tp_ranks: - found_replica = True + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: + self_replica_idx = i break - assert found_replica + assert self_replica_idx >= 0 + else: # The entire node is the tensor-parallel group - tp_ranks = node_ranks - - tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) - tp_size = dist.get_world_size(tp_group) - tp_rank = dist.get_rank(tp_group) + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx - # Data-parallelism across TP groups - dp_start = tp_rank - dp_end = dp_start + WORLD_SIZE - dp_ranks = list(range(dp_start, dp_end, tp_size)) - dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) else: if opts.num_replicas > 1: # Mixed data- and tensor-parallelism on a single node # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") - mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas)) - node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist() - - tp_ranks = mesh2d[node_idx[0], :].tolist() - tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) - - dp_ranks = mesh2d[:, node_idx[1]].tolist() - dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) else: dp_group = None tp_group = nccl_world - tp_rank = dist.get_rank(tp_group) - tp_size = dist.get_world_size(tp_group) - + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) dist_print( f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", group=tp_group, ) if dp_group is not None: + dp_rank = dist.get_rank(dp_group) dist_print( f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", group=dp_group, ) + else: + dp_rank = 0 # Intialize userbuffers hidden_size = opts.num_heads * opts.head_dim @@ -226,26 +327,12 @@ def dist_print(msg, group=None, end="\n", debug=False): ) # Initialize the fused LayerNorm + Multi-layer Perceptron module - torch.manual_seed(opts.seed + tp_rank) + torch.manual_seed(opts.seed + dp_rank) torch.cuda.manual_seed(opts.seed + tp_rank) - model = te.LayerNormMLP( - hidden_size, - opts.mlp_expansion_factor * hidden_size, - params_dtype=torch.bfloat16, - device="cuda", - tp_group=tp_group, - tp_size=tp_size, - set_parallel_mode=True, - sequence_parallel=True, # this is required for comm+GEMM overlap - seq_length=opts.seq_length, - ub_overlap_rs=not opts.no_comm_overlap, - ub_overlap_ag=not opts.no_comm_overlap, - ub_overlap_rs_dgrad=not opts.no_comm_overlap, - ub_bulk_dgrad=False, - ub_bulk_wgrad=not opts.no_comm_overlap, - ) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) if dp_group is not None: - model = DistributedDataParallel(model, process_group=dp_group) + model = DistributedDataParallel(model, dim=1, process_group=dp_group) # Initialize optimizer with model parameters optim = torch.optim.Adam(model.parameters(), lr=0.0001) @@ -255,28 +342,28 @@ def dist_print(msg, group=None, end="\n", debug=False): fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Start dummy "training" iterations - dist_print("Starting training iterations...", nccl_world) + dist_print("Starting training iterations...") for i in range(opts.num_iters): - dist_print(f" Iter {i+1}", tp_group, debug=True) - - dist_print(" |-- Generate random input batch", tp_group, debug=True) - x = torch.rand( - (opts.seq_length // tp_size, opts.batch_size, hidden_size), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - - dist_print(" |-- Forward pass", tp_group, debug=True) - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - dist_print(" |-- Compute loss", tp_group, debug=True) - loss = y.flatten().sum() - - dist_print(" |-- Backward pass", tp_group, debug=True) + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) loss.backward() - dist_print(" |-- Optimizer step", tp_group, debug=True) + dist_print(" |-- Optimizer step", group=tp_group, debug=True) optim.step() torch.cuda.synchronize() diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 4321432a2e..db3aa31951 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -4,11 +4,15 @@ set -xe +pip install "nltk>=3.8.2" pip install pytest==8.2.1 : ${TE_PATH:=/opt/transformerengine} pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' +# Test without custom calls +NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py + pip install -r $TE_PATH/examples/jax/mnist/requirements.txt pip install -r $TE_PATH/examples/jax/encoder/requirements.txt diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 109633495b..2c3b832933 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/jax -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/jax/test_sanity_import.py diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index e2d6d38dd4..30fbb1df1f 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -6,15 +6,28 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel==0.44.0 pydantic + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel -pip install dist/* -cd transformer_engine/paddle -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle -export NVTE_RELEASE_BUILD=0 +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +pip install dist/*.whl --no-deps + +cd transformer_engine/paddle +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel pip install dist/* python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 90c5e499f3..e6ccf3b82f 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -19,7 +19,6 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py -pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index e108e93cdb..fd8457c44b 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/pytorch -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/pytorch/test_sanity_import.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 71c55851d5..50394c33a9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -4,7 +4,12 @@ set -e +# pkg_resources is deprecated in setuptools 70+ and the packaging submodule +# has been removed from it. This is a temporary fix until upstream MLM fix. +pip install setuptools==69.5.1 + : ${TE_PATH:=/opt/transformerengine} +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py git clone https://github.com/NVIDIA/Megatron-LM.git cd Megatron-LM diff --git a/setup.py b/setup.py index 6a8bae2793..942f57d3c1 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,12 @@ """Installation script.""" import os +import time from pathlib import Path from typing import List, Tuple import setuptools +from wheel.bdist_wheel import bdist_wheel from build_tools.build_ext import CMakeExtension, get_build_ext from build_tools.utils import ( @@ -18,7 +20,8 @@ remove_dups, get_frameworks, install_and_import, - uninstall_te_fw_packages, + remove_dups, + uninstall_te_wheel_packages, ) from build_tools.te_version import te_version @@ -43,6 +46,16 @@ CMakeBuildExtension = get_build_ext(BuildExtension) +class TimedBdist(bdist_wheel): + """Helper class to measure build time""" + + def run(self): + start_time = time.perf_counter() + super().run() + total_time = time.perf_counter() - start_time + print(f"Total time for bdist_wheel: {total_time:.2f} seconds") + + def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" # Project directory root @@ -77,50 +90,85 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not found_pybind11(): setup_reqs.append("pybind11") + # Framework-specific requirements + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + if "pytorch" in frameworks: + install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) + test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + if "jax" in frameworks: + install_reqs.extend(["jax", "flax>=0.7.1"]) + test_reqs.extend(["numpy", "praxis"]) + if "paddle" in frameworks: + install_reqs.append("paddlepaddle-gpu") + test_reqs.append("numpy") + return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] if __name__ == "__main__": - # Dependencies - setup_requires, install_requires, test_requires = setup_requirements() - __version__ = te_version() - ext_modules = [setup_common_extension()] - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Remove residual FW packages since compiling from source - # results in a single binary with FW extensions included. - uninstall_te_fw_packages() - if "pytorch" in frameworks: - from build_tools.pytorch import setup_pytorch_extension - - ext_modules.append( - setup_pytorch_extension( - "transformer_engine/pytorch/csrc", - current_file_path / "transformer_engine" / "pytorch" / "csrc", - current_file_path / "transformer_engine", + with open("README.rst", encoding="utf-8") as f: + long_description = f.read() + + # Settings for building top level empty package for dependency management. + if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): + assert bool( + int(os.getenv("NVTE_RELEASE_BUILD", "0")) + ), "NVTE_RELEASE_BUILD env must be set for metapackage build." + ext_modules = [] + cmdclass = {} + package_data = {} + include_package_data = False + setup_requires = [] + install_requires = ([f"transformer_engine_cu12=={__version__}"],) + extras_require = { + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], + "paddle": [f"transformer_engine_paddle=={__version__}"], + } + else: + setup_requires, install_requires, test_requires = setup_requirements() + ext_modules = [setup_common_extension()] + cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} + package_data = {"": ["VERSION.txt"]} + include_package_data = True + extras_require = {"test": test_requires} + + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + # Remove residual FW packages since compiling from source + # results in a single binary with FW extensions included. + uninstall_te_wheel_packages() + if "pytorch" in frameworks: + from build_tools.pytorch import setup_pytorch_extension + + ext_modules.append( + setup_pytorch_extension( + "transformer_engine/pytorch/csrc", + current_file_path / "transformer_engine" / "pytorch" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "jax" in frameworks: - from build_tools.jax import setup_jax_extension - - ext_modules.append( - setup_jax_extension( - "transformer_engine/jax/csrc", - current_file_path / "transformer_engine" / "jax" / "csrc", - current_file_path / "transformer_engine", + if "jax" in frameworks: + from build_tools.jax import setup_jax_extension + + ext_modules.append( + setup_jax_extension( + "transformer_engine/jax/csrc", + current_file_path / "transformer_engine" / "jax" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", + if "paddle" in frameworks: + from build_tools.paddle import setup_paddle_extension + + ext_modules.append( + setup_paddle_extension( + "transformer_engine/paddle/csrc", + current_file_path / "transformer_engine" / "paddle" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) # Configure package setuptools.setup( @@ -133,15 +181,12 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: "transformer_engine/build_tools", ], ), - extras_require={ - "test": test_requires, - "pytorch": [f"transformer_engine_torch=={__version__}"], - "jax": [f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], - }, + extras_require=extras_require, description="Transformer acceleration library", + long_description=long_description, + long_description_content_type="text/x-rst", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8, <3.13", classifiers=[ "Programming Language :: Python :: 3.8", @@ -153,6 +198,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), - include_package_data=True, - package_data={"": ["VERSION.txt"]}, + include_package_data=include_package_data, + package_data=package_data, ) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e302be57bd..e590d8e92a 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,7 +10,6 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu - test_dgeglu.cu test_layernorm.cu test_rmsnorm.cu test_multi_cast_transpose.cu diff --git a/tests/cpp/operator/test_dgeglu.cu b/tests/cpp/operator/test_dgeglu.cu deleted file mode 100644 index 0924e2b4c9..0000000000 --- a/tests/cpp/operator/test_dgeglu.cu +++ /dev/null @@ -1,126 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include "../test_common.h" - -using namespace transformer_engine; - -namespace { - -template -inline CType gelu(const IType val) { - CType cval = val; - return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval))); -} - -template -inline CType dgelu(const IType val) { - CType cval = val; - const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); - return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) + - 0.5f * (1.f + tanh_out); -} - -template -void compute_ref_dgeglu(const IT *grad_h, const IT *input_h, OT *output_h, const size_t N, - const size_t H) { - const size_t col = H * 2; - - for (size_t i = 0; i < N; i++) { - for (size_t j = 0; j < H; j++) { - CT grad_elt = CT(grad_h[i * H + j]); - CT gelu_elt = CT(input_h[i * col + j]); - CT gate_elt = CT(input_h[i * col + H + j]); - - CT after_dgelu = dgelu(gelu_elt) * grad_elt * gate_elt; - CT after_dgate = grad_elt * gelu(gelu_elt); - - output_h[i * col + j] = OT(after_dgelu); - output_h[i * col + H + j] = OT(after_dgate); - } - } -} - -template -void performTestDGeGLU(const size_t N, const size_t H) { - using namespace test; - - using CType = fp32; - - DType itype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - Tensor grad({N, H}, itype); - Tensor input({N, H * 2}, itype); - Tensor output({N, H * 2}, otype); - - fillUniform(&grad); - fillUniform(&input); - - std::unique_ptr ref_output = std::make_unique(N * H * 2); - - nvte_dgeglu(grad.data(), input.data(), output.data(), 0); - - compute_ref_dgeglu(grad.cpu_dptr(), input.cpu_dptr(), - ref_output.get(), N, H); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_dgelu", output, ref_output.get(), atol, rtol); -} - -std::vector> test_cases = { - {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}}; - -} // namespace - -class DGeGLUTestSuite - : public ::testing::TestWithParam>> {}; - -TEST_P(DGeGLUTestSuite, TestDGeGLU) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - output_type, OutputType, - performTestDGeGLU(size.first, size.second););); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, DGeGLUTestSuite, - ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::ValuesIn(test_cases)), - [](const testing::TestParamInfo &info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); - return name; - }); diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 55494c42d6..ccb6690a87 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -2,9 +2,12 @@ # # See LICENSE for license information. """conftest for tests/jax""" +import os import jax import pytest +from transformer_engine.transformer_engine_jax import get_device_compute_capability + @pytest.fixture(autouse=True, scope="function") def clear_live_arrays(): @@ -14,3 +17,19 @@ def clear_live_arrays(): yield for arr in jax.live_arrays(): arr.delete() + + +@pytest.fixture(autouse=True, scope="module") +def enable_fused_attn(): + """ + Enable fused attn for hopper+ arch. + Fused attn kernels on pre-hopper arch are not deterministic. + """ + if get_device_compute_capability(0) >= 90: + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + yield + if "NVTE_FUSED_ATTN" in os.environ: + del os.environ["NVTE_FUSED_ATTN"] + if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ: + del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8664a03f8d..6991d83d4c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -19,8 +19,10 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp +from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu from transformer_engine.jax import cpp_extensions as tex + GEMM_CASES = [ (256, 256, 512), (32, 32, 32), @@ -34,21 +36,6 @@ is_fp8_supported, reason = is_fp8_available() -def _convert_to_activation_function(fn_or_string): - """Convert a string to an activation function.""" - if fn_or_string == "linear": - return lambda x: x - if fn_or_string == "quick_gelu": - return lambda x: nn.gelu(x, approximate=True) - if fn_or_string == "squared_relu": - return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)]) - if isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - if callable(fn_or_string): - return fn_or_string - raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") - - class TestFP8Dot: @staticmethod @@ -293,14 +280,7 @@ def layernorm_fp8_mlp_ref( bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = jnp.split(linear_1_out, len(activation_type), axis=-2) - acts = [] - for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) - - x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) + x = _jax_act_lu(linear_1_out, activation_type) fp8_meta_pkg_2 = FP8MetaPackage( amax_list_2[0], @@ -443,12 +423,7 @@ class TestActivationLu: def ref_func(self, x, activation_type): def ref_act_lu(inputs): - x = jnp.split(inputs, len(activation_type), axis=-2) - acts = [] - for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) + x = _jax_act_lu(inputs, activation_type) return jnp.mean(x) ref_act_func = jit(value_and_grad(ref_act_lu, (0,))) @@ -457,7 +432,7 @@ def ref_act_lu(inputs): def primitive_func(self, inputs): return jnp.mean(activation_lu(inputs, activation_type=self.activation_type)) - @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) + @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)]) @pytest.mark.parametrize( "activation_type", [ @@ -475,7 +450,7 @@ def primitive_func(self, inputs): ) def test_activation_lu(self, random_inputs, activation_type): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=1) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,))) @@ -536,7 +511,7 @@ def _prim_func_bwd(ctx, g): _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd) - dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype) + dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) amax_no_use = jnp.zeros(1, jnp.float32) value_n_grad_primitive_func = value_and_grad( @@ -545,7 +520,7 @@ def _prim_func_bwd(ctx, g): return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) + @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)]) @pytest.mark.parametrize( "activation_type", [ @@ -566,10 +541,12 @@ def test_activation_lu(self, random_inputs, activation_type): self.scale = jnp.ones(1, jnp.float32) self.scale_inv = jnp.ones(1, jnp.float32) self.activation_type = activation_type - self.transpose_indices = (1, 2, 0) x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=1) + x = jnp.repeat(x, len(activation_type), axis=-2) + axes = jnp.arange(x.ndim) + self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]]) + print(self.transpose_axes) prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x) ref_out, (ref_grad,) = self.ref_func(x, activation_type) @@ -581,7 +558,7 @@ def test_activation_lu(self, random_inputs, activation_type): assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose( prim_grad_trans, - jnp.transpose(ref_grad, self.transpose_indices), + jnp.transpose(ref_grad, self.transpose_axes), dtype=FP8Helper.BWD_DTYPE, ) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 796d5bcffa..390a3e2c4e 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -295,7 +295,10 @@ def _check_configs(self): if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: pytest.skip("Unsupported inputs combination or device compute capability.") - if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: + if ( + self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS + and self.bias_shape != BiasShape.BIAS_1HSS + ): if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: pytest.skip( "B1SS, BHSS and 11SS bias shapes are only supported for " @@ -391,7 +394,7 @@ def generate_random_segment_ids( return segment_ids, segment_pad if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.num_segments_per_seq = 3 + self.num_segments_per_seq = 2 self.token_q, self.segment_pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) @@ -461,7 +464,8 @@ def test_forward(self): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - "max_segments_per_seq": self.num_segments_per_seq, + # +1 for testing runtime_segments < max_segments + "max_segments_per_seq": self.num_segments_per_seq + 1, } # Convert the outputs to float32 for the elementwise comparison @@ -518,7 +522,7 @@ def grad_func(func, *args, **kwargs): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - "max_segments_per_seq": self.num_segments_per_seq, + "max_segments_per_seq": self.num_segments_per_seq + 1, } # We can compute dBias only for the [1, h, s, s] layout diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 92a6c80028..ccab73088a 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -15,7 +15,6 @@ from utils import assert_allclose -from transformer_engine.transformer_engine_jax import get_device_compute_capability from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.jax import fp8_autocast, update_collections from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral @@ -43,19 +42,6 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID] -@pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): - """ - Enable fused attn for hopper+ arch. - Fused attn kernels on pre-hopper arch are not deterministic. - """ - if get_device_compute_capability(0) >= 90: - os.environ["NVTE_FUSED_ATTN"] = "1" - yield - if "NVTE_FUSED_ATTN" in os.environ: - del os.environ["NVTE_FUSED_ATTN"] - - def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): for key in ref_fd: assert key in test_fd, f"{key} not found in test dict {test_fd}" diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 0cff5955fa..49e32e503c 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -123,14 +123,12 @@ def grad_func(func, *args, **kwargs): # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( - value_and_grad( - lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,) - ) + value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,)) ) jitted_reference = jit( value_and_grad( lambda logits, *args: grad_func( - __class__.reference_softmax, self.logits, *args, **kwargs + __class__.reference_softmax, logits, *args, **kwargs ), (0,), ) diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 6a985d7e86..b519fc0a0f 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -872,8 +872,9 @@ def test_layernorm_mlp_fp8_microbatch( @pytest.mark.parametrize("attn_type", ["self", "cross"]) @pytest.mark.parametrize("mask_type", ["causal", "padding"]) @pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) +@pytest.mark.parametrize("deterministic", [True, False]) def test_dot_product_attention( - bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype + bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic ): """ Test DotProductAttention Layer @@ -927,6 +928,10 @@ def test_dot_product_attention( attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False head_size = hidden_size // num_heads + + if deterministic: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + layer_te = te.DotProductAttention( num_heads, head_size, @@ -981,6 +986,15 @@ def calc_attn_output_and_grad(layer, q, k, v, mask, dout): assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) + if deterministic: + out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad( + layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out + ) + assert_allclose(out, out2, rtol=1e-12, atol=1e-12) + assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12) + assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12) + assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12) + os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) @pytest.mark.parametrize("bs", [1, 2]) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index d7dc3e1ce1..5ba70ccbdd 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -46,17 +46,20 @@ def _mapped_argtype(opt, typemap): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=512, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + "-n", "--num-heads", type=int, default=12, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) + parser.add_argument( + "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM." + ) parser.add_argument( "--p2p", action="store_true", default=False, help="Test overlap with P2P comms." ) @@ -106,7 +109,7 @@ def _parse_args(argv=None, namespace=None): help="Set device clock speed to a fixed value via `nvidia-smi`.", ) parser.add_argument( - "--scale", type=float, default=1e-2, help="Set scaling factor for input and weight tensors." + "--std", type=float, default=0.023, help="Standard deviation for input and weight tensors." ) parser.add_argument( "--tcp-init", @@ -135,6 +138,9 @@ def _parse_args(argv=None, namespace=None): + "initialization." ), ) + parser.add_argument( + "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs." + ) parser.add_argument( "-v", "--verbose", action="store_true", default=False, help="Verbose info messages." ) @@ -150,14 +156,17 @@ def _parse_args(argv=None, namespace=None): if opts.fp8: warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") opts.fp8 = False - elif opts.comm_type == 1 and not opts.p2p: - warnings.warn("All-gather overlap is only supported with point-2-point comms.") - opts.p2p = True + elif opts.comm_type == 1: + if opts.atomic: + setattr(opts, "atomic_rs_p2p", opts.p2p) + if not opts.p2p: + warnings.warn("All-gather overlap is only supported with point-2-point comms.") + opts.p2p = True if opts.atomic: if not te.fp8.check_fp8_support(): assert not opts.fp8, "Atomic GEMM is only supported in FP8." - opts.fp8 = True + opts.fp8 = True return opts @@ -203,13 +212,14 @@ def _main(opts): print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True) # Info printout - def dist_print(msg, src=None, info=False, section=False, group=None): + def dist_print(msg, src=None, info=False, error=False, section=False, group=None): group = dist.new_group() if group is None else group rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout if info or opts.verbose: if section: if rank == (0 if src is None else src): - print("\n", end="", flush=True) + stream.write("\n") dist.barrier(group) if src is None or rank == src: prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] " @@ -217,7 +227,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): msg = "\n".join( [prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]] ) - print(msg + "\n", end="", flush=True) + stream.write(msg + "\n") dist.barrier(group) # Initialize torch.distributed global process group and get TP group @@ -312,7 +322,9 @@ def dist_print(msg, src=None, info=False, section=False, group=None): hidden_size = opts.num_heads * opts.head_dim inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) - ubuf_dtype = torch.uint8 if opts.fp8 and opts.comm_type == 1 else torch.bfloat16 + ubuf_dtype = torch.bfloat16 + if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output): + ubuf_dtype = torch.uint8 sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") ub_obj = ub_obj = ( tex.UbufP2PCommOverlap( @@ -331,7 +343,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): 3, # Max concurrent GEMM streams opts.comm_type == 0, # overlap with reduce scatter opts.atomic, # use a single GEMM with atomic-counters - True, # Use copy engine for P2P communications + not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ub_callbacks, ) if opts.p2p @@ -349,7 +361,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): 4, # Number of communication splits True, # Set SM margin 3, # Max concurrent GEMM streams - opts.atomic, # uUe a single GEMM with atomic-counters + opts.atomic, # Use a single GEMM with atomic-counters ub_callbacks, ) ) @@ -357,25 +369,49 @@ def dist_print(msg, src=None, info=False, section=False, group=None): # Numerical check on AG + atomic GEMM requires testing an AG+RS pair ub_obj2 = None if opts.atomic and opts.comm_type == 1 and opts.check_numerics: - sample_buffer2 = torch.empty((outer_size, hidden_size), dtype=torch.bfloat16, device="cuda") - ub_obj2 = tex.UbufP2PCommOverlap( - sample_buffer2, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - True, # Set SM margin - False, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - True, # overlap with reduce scatter - True, # use a single GEMM with atomic-counters - True, # use copy engine for P2P communications - ub_callbacks, + sample_buffer2 = torch.empty( + (outer_size, hidden_size), + dtype=torch.uint8 if opts.fp8_output else torch.bfloat16, + device="cuda", + ) + ub_obj2 = ( + tex.UbufP2PCommOverlap( + sample_buffer2, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 1, # Number of communication SMs + 1, # CGA cluster size + True, # Set SM margin + False, # Aggregate 2X GEMM chunks + 3, # Max concurrent GEMM streams + True, # overlap with reduce scatter + True, # use a single GEMM with atomic-counters + True, # use copy engine for P2P communications + ub_callbacks, + ) + if opts.atomic_rs_p2p + else tex.UbufCommOverlap( + sample_buffer2, # Sample userbuffer + WORLD_RANK, # World rank + WORLD_SIZE, # World size + LOCAL_RANK, # Rank within the node + LOCAL_SIZE, # Number of ranks/GPUs per node + 0, # Node ID + 1, # Number of nodes + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + 16, # Number of communication SMs + 2, # CGA cluster size + 4, # Number of communication splits + True, # Set SM margin + 3, # Max concurrent GEMM streams + True, # uUe a single GEMM with atomic-counters + ub_callbacks, + ) ) # Figure out problem sizing: @@ -409,43 +445,53 @@ def dist_print(msg, src=None, info=False, section=False, group=None): # Initialize distributed input tensor and GEMM kernels torch.manual_seed(opts.seed + tp_rank) torch.cuda.manual_seed(opts.seed + tp_rank) - inp = torch.mul(torch.rand(local_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale) - kernel_t = torch.mul( - torch.rand(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + inp = torch.nn.init.normal_( + torch.empty(local_inp_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, + ) + kernel_t = torch.nn.init.normal_( + torch.empty(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, ) if ub_obj2 is not None: - kernel2_t = torch.mul( - torch.rand(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + kernel2_t = torch.nn.init.normal_( + torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, ) # Gather global tensors and calculate reference result (need these first for Fp8 scales) if opts.bulk_overlap: ker_g = torch.transpose(kernel_t, 0, 1) inp_g = inp - bulk_inp = torch.mul( - torch.rand(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale + bulk_inp = torch.nn.init.normal_( + torch.empty(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), + mean=0.0, + std=opts.std, ) else: if opts.comm_type == 1: # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) ker_g = torch.transpose( te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 - ) + ).to(dtype=torch.float32) # AG Input: (M/P, N) -> gather -> (M, N) - inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0] + inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0].to(dtype=torch.float32) if ub_obj2 is not None: ker2_g = te.distributed.gather_along_first_dim( torch.transpose(kernel2_t, 0, 1), tp_group - )[0] + )[0].to(dtype=torch.float32) else: # RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N) ker_g = te.distributed.gather_along_first_dim( torch.transpose(kernel_t, 0, 1), tp_group - )[0] + )[0].to(dtype=torch.float32) # RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) inp_g = torch.transpose( te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1 - ) + ).to(dtype=torch.float32) if opts.bulk_overlap: if opts.comm_type == 1: @@ -459,7 +505,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): else: ref_g = torch.matmul(inp_g, ker_g) if ub_obj2 is not None: - inp2_g = torch.mul(ref_g, opts.scale) + inp2_g = torch.nn.functional.gelu(ref_g) ref2_g = torch.matmul(inp2_g, ker2_g) if opts.fp8: @@ -483,7 +529,10 @@ def dist_print(msg, src=None, info=False, section=False, group=None): fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) - if ub_obj2 is not None: + if opts.bulk_overlap and opts.comm_type == 0: + bulk_amax = torch.max(torch.abs(bulk_inp)) + fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) + elif ub_obj2 is not None: inp2_amax = torch.max(torch.abs(inp2_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) ker2_amax = torch.max(torch.abs(ker2_g)) @@ -502,7 +551,11 @@ def dist_print(msg, src=None, info=False, section=False, group=None): kernel_t_fp8 = tex.cast_to_fp8( kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype ) - if ub_obj2 is not None: + if opts.bulk_overlap and opts.comm_type == 0: + bulk_inp_fp8 = tex.cast_to_fp8( + bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype + ) + elif ub_obj2 is not None: kernel2_t_fp8 = tex.cast_to_fp8( kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype ) @@ -521,7 +574,14 @@ def dist_print(msg, src=None, info=False, section=False, group=None): rtol=0.125, atol=0.0675, ) - if ub_obj2 is not None: + if opts.bulk_overlap and opts.comm_type == 0: + torch.allclose( + bulk_inp.to(dtype=torch.float32), + bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], + rtol=0.125, + atol=0.0675, + ) + elif ub_obj2 is not None: torch.allclose( kernel2_t.to(dtype=torch.float32), kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], @@ -534,6 +594,8 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) if ub_obj2 is not None: ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) + elif opts.bulk_overlap: + ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) else: ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) @@ -556,7 +618,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ) else: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp, 0) + ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0) ubuf_out = None else: ubuf_out = ub_obj.get_ubuf_output(1) @@ -565,80 +627,131 @@ def dist_print(msg, src=None, info=False, section=False, group=None): (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" ) + # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use + def _fp8_gemm(): + return tex.fp8_gemm( + kernel_t_fp8, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype, + gemm_inp, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype, + torch.uint8 if opts.fp8_output else torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + use_split_accumulator=te.module.base._2X_ACC_FPROP, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out=ubuf_out, + D_dtype=fp8_dtype if opts.fp8_output else None, + fp8_meta_tensor=fp8_meta if opts.fp8_output else None, + out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None, + ) + + def _fp8_gemm2(gemm1_out): + gemm2_inp = tex.gelu( + ( + tex.cast_from_fp8( + gemm1_out, + fp8_meta, + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype, + tex.DType.kFloat32, + ) + if opts.fp8_output + else gemm1_out + ), + fp8_meta, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype, + ) + return tex.fp8_gemm( + kernel2_t_fp8, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_dtype, + gemm2_inp, + fp8_meta.scale_inv, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype, + torch.uint8 if opts.fp8_output else torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + use_split_accumulator=te.module.base._2X_ACC_FPROP, + ub_algo=( + tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + if opts.atomic_rs_p2p + else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ), + ub=ub_obj2, + extra_output_tensor=rs_out2, + out=ubuf_out2, + D_dtype=fp8_dtype if opts.fp8_output else None, + fp8_meta_tensor=fp8_meta if opts.fp8_output else None, + out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None, + ) + + def _gemm(): + return tex.gemm( + kernel_t, + gemm_inp, + torch.bfloat16, + te.module.base.get_workspace(), + bias=None, + use_bias=False, + gelu=False, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out=ubuf_out, + ) + # Trigger GEMM total_iters = opts.warmup_iters + opts.timing_iters start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] torch.cuda.synchronize() - if opts.fp8: + if opts.use_cuda_graphs: + # Trace the CUDA graph first + g = torch.cuda.CUDAGraph() + if opts.fp8: + if ub_obj is None: + with torch.cuda.graph(g): + all_outputs = _fp8_gemm() + else: + with torch.cuda.graph(g): + all_outputs = _fp8_gemm() + _ = _fp8_gemm2(all_outputs[0]) + else: + with torch.cuda.graph(g): + all_outputs = _gemm() + + # Now replay the CUDA graph in a loop for i in range(total_iters): start_events[i].record() - all_outputs = tex.fp8_gemm( - kernel_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, - gemm_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - ) + g.replay() end_events[i].record() - if ub_obj2 is not None: - gemm2_inp = tex.cast_to_fp8( - torch.mul(all_outputs[0], opts.scale), - fp8_meta, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - ) - all_outputs = tex.fp8_gemm( - kernel2_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype, - gemm2_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, - ub=ub_obj2, - extra_output_tensor=rs_out2, - out=ubuf_out2, - ) + else: for i in range(total_iters): - start_events[i].record() - all_outputs = tex.gemm( - kernel_t, - gemm_inp, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - ) - end_events[i].record() + if opts.fp8: + start_events[i].record() + all_outputs = _fp8_gemm() + end_events[i].record() + if ub_obj2 is not None: + _fp8_gemm2(all_outputs[0]) + else: + start_events[i].record() + all_outputs = _gemm() + end_events[i].record() torch.cuda.synchronize() gpu_times = [ @@ -679,7 +792,11 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ref_out = ref_g output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" - dist_print(output_info, src=0 if opts.comm_type == 0 else None, section=True) + dist_print( + output_info, + src=0 if opts.comm_type == 0 else None, + section=True, + ) test_nonzeros = torch.count_nonzero(test_out) ref_nonzeros = torch.count_nonzero(ref_out) @@ -691,11 +808,21 @@ def dist_print(msg, src=None, info=False, section=False, group=None): if opts.comm_type == 1: if ub_obj2 is not None: # AG+RS Output: (M/P, N) -> gather -> (M, N) - output = rs_out2 + output = rs_out2.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] else: # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) - output = all_outputs[0] + output = ( + tex.cast_from_fp8( + all_outputs[0], + fp8_meta, + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype, + tex.DType.kFloat32, + ) + if opts.fp8_output + else all_outputs[0] + ) test_out = torch.transpose( te.distributed.gather_along_first_dim( torch.transpose(output, 0, 1), tp_group @@ -705,7 +832,7 @@ def dist_print(msg, src=None, info=False, section=False, group=None): ) else: # RS Output: (M/P, N) -> gather -> (M, N) - output = rs_out + output = rs_out.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] if opts.fp8: @@ -755,30 +882,33 @@ def dist_print(msg, src=None, info=False, section=False, group=None): torch.cuda.synchronize() dist.barrier(tp_group) - test_out = test_out.to(dtype=torch.float32) - ref_out = ref_out.to(dtype=torch.float32) - error_below_tol = torch.allclose( - test_out, - ref_out, - rtol=0.125 if opts.fp8 else 0.02, - atol=0.0675 if opts.fp8 else 0.001, - ) diff = torch.abs(test_out - ref_out).flatten() m = torch.argmax(diff) abs_err = diff[m].item() - rel_err = abs_err / (ref_out.flatten()[m].item() + 1e-5) - if not error_below_tol: + rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5) + rtol = 0.125 if opts.fp8 else 0.02 + atol = 0.0625 if opts.fp8 else 0.001 + if rel_err > rtol and abs_err > atol: numerics_failed = True numerics_info = ( "NUMERICAL CHECK FAILED: " + f"Outputs not close enough at index {m.item()} " - + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} " - + f"(abs error = {abs_err} | rel error = {rel_err})." + + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" ) else: - numerics_info = f"NUMERICAL CHECK PASSED: abs error = {abs_err} | rel error = {rel_err}" + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" - dist_print(numerics_info, src=0, section=True, info=True, group=tp_group) + dist_print( + numerics_info, src=0, section=True, info=True, error=numerics_failed, group=tp_group + ) dist.barrier(tp_group) if LOCAL_RANK == 0: diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py new file mode 100644 index 0000000000..e5653bda01 --- /dev/null +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -0,0 +1,352 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import socket +import argparse +import warnings +from functools import partial + +import torch +import torch.distributed as dist + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not reference + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not reference + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + kwargs["ub_bulk_wgrad"] = not reference + kwargs["ub_bulk_dgrad"] = not reference + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not reference + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(4 * hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not reference + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP) + parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + args = parser.parse_args(argv, namespace) + + if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!") + args.use_cuda_graphs = False + + return args + + +def _compare_tensors(name, test, ref, rtol, atol): + # Make sure tensors aren't zero and we don't pass trivially + if test.count_nonzero() == 0: + if ref.count_nonzero() == 0: + warnings.warn( + f"WARNING: {name} is a zero-tensor for both test and reference models!", + category=RuntimeWarning, + ) + else: + numerics_info = ( + f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!" + ) + return 1, numerics_info + + diff = torch.abs(test - ref).flatten() + m = torch.argmax(diff) + abs_err = diff[m].item() + rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5) + numerics_failed = 0 + if rel_err > rtol and abs_err > atol: + numerics_failed = 1 + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"{name} not close enough at index {m.item()} " + + f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = f"NUMERICAL CHECK PASSED: {name} | " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err <= atol else "." + ) + if abs_err <= atol: + numerics_info += f" abs. error = {abs_err} (tol = {atol})" + + return numerics_failed, numerics_info + + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + assert LOCAL_SIZE == WORLD_SIZE + + def dist_print(msg, src=None, end="\n", debug=False, error=False): + if debug and not opts.debug: + return + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + dist.barrier() + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(opts.seed) + torch.cuda.manual_seed(opts.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init: + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Intialize userbuffers + te.module.base.initialize_ub( + [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], + WORLD_SIZE, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ) + + # Initialize the Transformer Engine layer with overlap + args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE) + with te.fp8_model_init(enabled=opts.fp8_init): + test_model = opts.layer_type(*args, **kwargs) + dist_print("Initialized test model...", debug=True) + + # Initialize the reference model and copy all parameters + ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) + with te.fp8_model_init(enabled=opts.fp8_init): + ref_model = opts.layer_type(*ref_args, **ref_kwargs) + dist_print("Initialized reference model...", debug=True) + for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): + with torch.no_grad(): + ref_param.copy_(test_param) + torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0) + dist_print("Copied parameters from test model to reference model...", debug=True) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + # Prepare random input tensors + test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) + test_x.retain_grad() + ref_x = torch.empty_like(test_x).requires_grad_(True) + with torch.no_grad(): + ref_x.copy_(test_x) + torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0) + ref_x.retain_grad() + + # Execute fwd/bwd and collect tensors to test + def run_fwd_bwd(model, x): + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + loss = out.sum() + loss.backward() + return out + + torch_rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}")) + if opts.use_cuda_graphs: + test_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(test_graph): + test_out = run_fwd_bwd(test_model, test_x) + test_graph.replay() + del test_graph + else: + test_out = run_fwd_bwd(test_model, test_x) + test_grads = [test_out, test_x.grad] + names = ["output", "input.grad"] + for test_name, test_param in test_model.named_parameters(): + if test_param.requires_grad and "layer_norm" not in test_name: + test_grads.append(test_param.grad) + names.append(test_name + ".grad") + + torch.set_rng_state(torch_rng_state) + torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}")) + if opts.use_cuda_graphs: + ref_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(ref_graph): + ref_out = run_fwd_bwd(ref_model, ref_x) + ref_graph.replay() + del ref_graph + else: + ref_out = run_fwd_bwd(ref_model, ref_x) + ref_grads = [ref_out, ref_x.grad] + for ref_name, ref_param in ref_model.named_parameters(): + if ref_param.requires_grad and "layer_norm" not in ref_name: + ref_grads.append(ref_param.grad) + + # Make sure we have the same number of gradients + numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + if len(test_grads) != len(ref_grads): + numerics_failed[0] = 1 + numerics_info = ( + "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + + f"expected {len(ref_grads)} but got {len(test_grads)}." + ) + dist_print(numerics_info, src=WORLD_RANK, error=True) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + + # Now validate accuracy + if not bool(numerics_failed.item()): + for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): + rtol = 0.125 if opts.fp8 else 0.025 + atol = 0.0625 if opts.fp8 else 0.00125 + grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) + dist_print(grad_info, src=WORLD_RANK, error=grad_failed) + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + if bool(numerics_failed.item()): + break + + te.module.base.destroy_ub() + dist_print("Destroying Userbuffers objects...", debug=True) + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return numerics_failed[0].item() + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index d0745aebf6..63310195ae 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -7,16 +7,27 @@ import pytest import torch +import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +if torch.cuda.device_count() < 2: + pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") + fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() RNG_SEED: int = 1234 -SEQ_LENGTH: int = 2024 +SEQ_LENGTH: int = 512 BATCH_SIZE: int = 2 -NUM_HEADS: int = 64 -HEAD_DIM: int = 128 +NUM_HEADS: int = 12 +HEAD_DIM: int = 64 +TE_LAYERS = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, +] TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(torch.cuda.device_count(), 4) @@ -32,66 +43,28 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.") -@pytest.mark.parametrize( - "fp8,p2p,comm_type,aggregate,atomic,bulk", - [ - # FP8, P2P, Type, Aggregate, Atomic, Bulk - (False, True, "AG", False, False, False), - (False, True, "AG", True, False, False), - (True, True, "AG", False, False, False), - (True, True, "AG", True, False, False), - (False, False, "RS", False, False, False), - (False, True, "RS", False, False, False), - (True, False, "RS", False, False, False), - (True, True, "RS", False, False, False), - (True, False, "RS", False, True, False), - (True, True, "RS", False, True, False), - (False, False, "AG", False, False, True), - (False, False, "RS", False, False, True), - ], - ids=[ - " AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ", - " AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ", - " AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ", - " AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ", - " SPLIT GEMM -> RS | BF16 | PIPELINE ", - " SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ", - " SPLIT GEMM -> RS | FP8 | PIPELINE ", - " SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ", - " ATOMIC GEMM -> RS | FP8 | PIPELINE ", - " ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ", - " BULK AG & GEMM | BF16 | PIPELINE ", - " BULK RS & GEMM | BF16 | PIPELINE ", - ], -) -def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): - """ - Test comm+GEMM overlap algorithms with direct calls to - te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm - """ +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): test_path = TEST_ROOT / "run_gemm_with_overlap.py" - test_cmd = ( - LAUNCH_CMD - + [str(test_path)] - + [ - "--check-numerics", - f"--seed={RNG_SEED}", - f"--seq-length={SEQ_LENGTH}", - f"--batch-size={BATCH_SIZE}", - f"--num-heads={NUM_HEADS}", - f"--head-dim={HEAD_DIM}", - f"--comm-type={comm_type}", - ] - ) + test_cmd = LAUNCH_CMD + [ + str(test_path), + "--check-numerics", + f"--seed={RNG_SEED}", + f"--seq-length={SEQ_LENGTH}", + f"--batch-size={BATCH_SIZE}", + f"--num-heads={NUM_HEADS}", + f"--head-dim={HEAD_DIM}", + f"--comm-type={comm_type}", + ] if bulk: test_cmd.append("--bulk-overlap") else: - if fp8: + if fp8_in: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") + if fp8_out: + test_cmd.append("--fp8-output") if p2p: test_cmd.append("--p2p") if aggregate: @@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") test_cmd.append("--atomic") - output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False) - assert "NUMERICAL CHECK PASSED" in str(output) + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if ( + result.returncode != 0 + or "NUMERICAL CHECK FAILED" in result.stderr.decode() + or "NUMERICAL CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError(result.stderr.decode()) + + +def _run_layer_with_overlap(layer_type, fp8, fp8_init): + test_path = TEST_ROOT / "run_layer_with_overlap.py" + test_cmd = LAUNCH_CMD + [ + str(test_path), + f"--seed={RNG_SEED}", + f"--seq-length={SEQ_LENGTH}", + f"--batch-size={BATCH_SIZE}", + f"--num-heads={NUM_HEADS}", + f"--head-dim={HEAD_DIM}", + f"--layer-type={layer_type}", + ] + + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + test_cmd.append("--fp8") + if fp8_init: + test_cmd.append("--fp8-init") + + os.environ["PYTORCH_JIT"] = "0" + os.environ["NVTE_TORCH_COMPILE"] = "0" + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + + os.unsetenv("PYTORCH_JIT") + os.unsetenv("NVTE_TORCH_COMPILE") + os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO") + + if ( + result.returncode != 0 + or "NUMERICAL CHECK FAILED" in result.stderr.decode() + or "NUMERICAL CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError(result.stderr.decode()) + + +@pytest.mark.parametrize( + "fp8,aggregate", + [ + (False, False), + (False, True), + (True, False), + (True, True), + ], + ids=[ + " BF16 IN - RING-EXCHANGE ", + " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ", + " FP8 IN - RING-EXCHANGE ", + " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ", + ], +) +def test_split_all_gather_overlaps(fp8, aggregate): + """ + Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or + te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate) + + +@pytest.mark.parametrize( + "fp8_in,fp8_out,p2p", + [ + (False, False, False), + (False, False, True), + (True, False, False), + (True, False, True), + (True, True, False), + (True, True, True), + ], + ids=[ + " BF16 IN - BF16 OUT - PIPELINE ", + " BF16 IN - BF16 OUT - RING-EXCHANGE ", + " FP8 IN - BF16 OUT - PIPELINE ", + " FP8 IN - BF16 OUT - RING-EXCHANGE ", + " FP8 IN - FP8 OUT - PIPELINE ", + " FP8 IN - FP8 OUT - RING-EXCHANGE ", + ], +) +def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p): + """ + Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or + te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False) + + +@pytest.mark.parametrize( + "ag_type,rs_type,p2p,fp8_out", + [ + (0, 0, False, False), + (0, 1, False, False), + (0, 1, False, True), + (0, 2, False, False), + (0, 2, False, True), + (0, 0, True, False), + (0, 0, True, True), + (1, 0, True, False), + (1, 0, True, True), + ], + ids=[ + " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ", + " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ", + " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ", + " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ", + " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ", + " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", + " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", + " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", + " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", + ], +) +def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): + """ + Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with + direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. + """ + os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type) + os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type) + _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False) + + +@pytest.mark.parametrize( + "comm_type,fp8", + [ + ("AG", False), + ("RS", False), + ("RS", True), + ], + ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "], +) +def test_bulk_overlaps(comm_type, fp8): + """ + Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + + +@pytest.mark.parametrize( + "layer_type", + [layer.__name__ for layer in TE_LAYERS], + ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], +) +@pytest.mark.parametrize( + "fp8,fp8_init", + [ + (False, False), + (True, False), + (True, True), + ], + ids=[ + " BF16 GEMM - BF16 PARAMS ", + " FP8 GEMM - BF16 PARAMS ", + " FP8 GEMM - FP8 PARAMS ", + ], +) +def test_layers_with_overlap(layer_type, fp8, fp8_init): + """ + Test Transformer Engine layers with comm+GEMM overlap. + """ + _run_layer_with_overlap(layer_type, fp8, fp8_init) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 9b9b7686c2..6c775fb127 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -2,17 +2,23 @@ # # See LICENSE for license information. -import os, sys +import os, sys, logging +from contextlib import nullcontext import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention +from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling -dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} +dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} -def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"): +def run_dpa_with_cp( + dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p" +): """Test DotProductAttention module with context parallelism""" os.environ["NVTE_FLASH_ATTN"] = "0" @@ -23,10 +29,16 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" config = model_configs_fused_attn[model] - if qkv_format == "thd" and ( - config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias" - ): - return + + assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type} is an unsupported attention mask type!" + if kernel_backend == "FusedAttention" and qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) @@ -48,71 +60,98 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - assert config.attn_mask_type in [ - "causal", - "no_mask", - ], f"{config.attn_mask_type} is an unsupported attention mask type!" - - if kernel_backend == "FusedAttention" and qkv_format == "thd": - if "causal" in config.attn_mask_type: - config.attn_mask_type = "padding_causal" - else: - config.attn_mask_type = "padding" + if dtype == "fp8": + fp8_recipe = DelayedScaling(fp8_dpa=True) # instantiate core attn module core_attn = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, attn_mask_type=config.attn_mask_type, + window_size=config.window_size, ) core_attn = core_attn.cuda() # create flash attn inputs if qkv_format == "bshd": - q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim) + q_input_shape = ( + config.batch_size, + config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + ) kv_input_shape = ( config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, - config.head_dim, + config.head_dim_qk, ) attn_output_shape = ( config.batch_size, config.max_seqlen_q, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, ) cu_seqlens_q = None cu_seqlens_kv = None + cu_seqlens_q_padded = None + cu_seqlens_kv_padded = None elif qkv_format == "sbhd": - q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim) + q_input_shape = ( + config.max_seqlen_q, + config.batch_size, + config.num_heads, + config.head_dim_qk, + ) kv_input_shape = ( config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, - config.head_dim, + config.head_dim_qk, ) attn_output_shape = ( config.max_seqlen_q, config.batch_size, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, ) cu_seqlens_q = None cu_seqlens_kv = None + cu_seqlens_q_padded = None + cu_seqlens_kv_padded = None elif qkv_format == "thd": - seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to( - torch.int32 + q_input_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + ) + kv_input_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_gqa_groups, + config.head_dim_qk, + ) + attn_output_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_heads * config.head_dim_qk, ) - seqlens_q = seqlens_q - seqlens_q % (world_size * 2) - cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)]) + seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) + seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) + cu_seqlens_q_padded = torch.cat( + [ + torch.zeros([1], dtype=torch.int32), + seqlens_q_padded.cumsum(0, dtype=torch.int32), + torch.tensor([q_input_shape[0]], dtype=torch.int32), + ] + ).cuda() + if kernel_backend == "FlashAttention": + cu_seqlens_q = cu_seqlens_q_padded[:-1] + else: + cu_seqlens_q = torch.cat( + [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)] + ).cuda() cu_seqlens_kv = cu_seqlens_q - q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim) - kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim) - attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim) - cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda() - cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda() + cu_seqlens_kv_padded = cu_seqlens_q_padded else: assert False, f"{qkv_format} is an unsupported qkv_format!" @@ -132,22 +171,33 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= for x in [q, k, v, dout] + ([] if bias is None else [bias]): dist.broadcast(x, 0, group=cp_comm_group) if qkv_format == "thd": - for x in [cu_seqlens_q, cu_seqlens_kv]: + for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]: dist.broadcast(x, 0, group=cp_comm_group) # run core_attn without CP for x in [q, k, v]: x.requires_grad = True - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - ) - out.backward(dout) + + if dtype == "fp8": + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out.backward(dout) # run core_attn wit CP q_, k_, v_, dout_, *rest = [ @@ -171,12 +221,14 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] ] elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank) - seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank) + seq_idx_q = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, q_.shape[0], world_size, rank + ) + seq_idx_kv = tex.thd_get_partitioned_indices( + cu_seqlens_kv_padded, k_.shape[0], world_size, rank + ) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - cu_seqlens_q = cu_seqlens_q // world_size - cu_seqlens_kv = cu_seqlens_kv // world_size else: assert False, f"{qkv_format} is an unsupported qkv_format!" q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] @@ -186,34 +238,37 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) - core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) - max_seqlen_q = config.max_seqlen_q - max_seqlen_kv = config.max_seqlen_kv - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, + core_attn.set_context_parallel_group( + cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) - out_.backward(dout_) + + if dtype == "fp8": + core_attn.reset_fp8_meta_tensors() + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) assert torch.all(~torch.isinf(x)) # compare results with and without CP - tols = dict(atol=5e-3, rtol=5e-3) - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - if qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -230,38 +285,97 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend= for x in [q_.grad, k_.grad, v_.grad, out_] ] elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]] + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] + dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]] + ).item() + == 0 + ) + cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[ + b + 1 + ] + ] + ).item() + == 0 + ) else: assert False, f"{qkv_format} is an unsupported qkv_format!" + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + tols = dict(atol=2.5e-2, rtol=2.5e-2) + else: + tols = dict(atol=3.5e-2, rtol=3.5e-2) + elif dtype == "fp16": + tols = dict(atol=5e-3, rtol=5e-3) + elif dtype == "fp8": + tols = dict(atol=5e-1, rtol=5e-1) + rmse_tol = 0.1 + else: + assert False, f"{dtype} is an unsupported dtype!" + + def _rmse(a, b): + return torch.sqrt((a - b).square().mean()).item() + + def _error(a, b): + if dtype != "fp8": + torch.testing.assert_close(a, b, **tols) + else: + try: + torch.testing.assert_close(a, b, **tols) + except Exception as e: + logging.debug(e) + + rmse = _rmse(a, b) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert ( + rmse < rmse_tol * rmse_range + ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + if qkv_format == "bshd": - torch.testing.assert_close(out_[:, 0], out[:, 0], **tols) - torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols) - torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols) - torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols) - torch.testing.assert_close(out_[:, 1], out[:, 1], **tols) - torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols) - torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols) - torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[:, 0], b[:, 0]) + _error(a[:, 1], b[:, 1]) elif qkv_format == "sbhd": - torch.testing.assert_close(out_[0], out[0], **tols) - torch.testing.assert_close(dq_[0], dq[0], **tols) - torch.testing.assert_close(dk_[0], dk[0], **tols) - torch.testing.assert_close(dv_[0], dv[0], **tols) - torch.testing.assert_close(out_[1], out[1], **tols) - torch.testing.assert_close(dq_[1], dq[1], **tols) - torch.testing.assert_close(dk_[1], dk[1], **tols) - torch.testing.assert_close(dv_[1], dv[1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[0], b[0]) + _error(a[1], b[1]) elif qkv_format == "thd": - torch.testing.assert_close(out_, out, **tols) - torch.testing.assert_close(dq_, dq, **tols) - torch.testing.assert_close(dk_, dk, **tols) - torch.testing.assert_close(dv_, dv, **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a, b) else: assert False, f"{qkv_format} is an unsupported qkv_format!" + dist.destroy_process_group() + def main(**kwargs): run_dpa_with_cp(**kwargs) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 73dfa23d9a..82a3c8576b 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -8,6 +8,7 @@ import os from importlib.metadata import version from typing import Any, Dict, List, Tuple, Union, Optional +from contextlib import contextmanager import pytest import torch @@ -77,12 +78,13 @@ def __init__( batch_size: int, num_heads: int, num_gqa_groups: int, - head_dim: int, + head_dim_qk: int, max_seqlen_q: int, max_seqlen_kv: int, dropout_p: float, attn_mask_type: str, attn_bias_type: str, + head_dim_v: int = None, alibi_type: str = "none", num_layers: int = 1, bias_shape: str = "1hss", @@ -91,9 +93,10 @@ def __init__( self.batch_size = batch_size self.num_heads = num_heads self.num_gqa_groups = num_gqa_groups - self.head_dim = head_dim - self.hidden_size = num_heads * head_dim - self.hidden_size_kv = num_gqa_groups * head_dim + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + self.hidden_size = num_heads * head_dim_qk + self.hidden_size_kv = num_gqa_groups * self.head_dim_v self.max_seqlen_q = max_seqlen_q self.max_seqlen_kv = max_seqlen_kv self.dropout_p = dropout_p @@ -106,6 +109,16 @@ def __init__( self.window_size = window_size +@contextmanager +def logging_context(highest_level=logging.WARNING): + previous_level = logging.root.manager.disable + logging.disable(highest_level) + try: + yield + finally: + logging.disable(previous_level) + + def _get_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, @@ -137,7 +150,11 @@ def _get_attention_backends( ) core_attention_bias_requires_grad = False # d=256 is supported by cuDNN 9.0+ for inference but not training - if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128: + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): core_attention_bias_requires_grad = True fused_attn_backends = [] @@ -153,7 +170,8 @@ def test(): num_gqa_groups=config.num_gqa_groups, max_seqlen_q=config.max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, - head_dim=config.head_dim, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes_shape, @@ -173,12 +191,13 @@ def test(): return available_backends, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + with logging_context(): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, fused_attn_backends @@ -218,11 +237,12 @@ def test_dot_product_attention( if dtype == torch.bfloat16: tols = dict(atol=2.5e-2, rtol=2.5e-2) config = model_configs[model] + is_mla = config.head_dim_qk != config.head_dim_v if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" + qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" else: - qkv_layout = "sbhd_sb2hd" + qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") @@ -241,14 +261,17 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes - if pad_between_seqs: + if pad_between_seqs and not ( + config.max_seqlen_q != config.max_seqlen_kv + and config.attn_mask_type in ["causal", "padding_causal"] + ): flash_attn_supported = True # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") - is_training = config.head_dim <= 128 + is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 # UnfusedDotProductAttention backend if unfused_attn_supported: unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( @@ -343,6 +366,38 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_mla = { + # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend + "mla_1_0": ModelConfig( + 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # self , 0 + "mla_1_1": ModelConfig( + 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # cross, 0 + "mla_2_0": ModelConfig( + 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 + ), # self , 1 + "mla_2_1": ModelConfig( + 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + ), # cross, 1 + "mla_3_0": ModelConfig( + 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 + ), # inference + "mla_3_1": ModelConfig( + 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # inference +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_mla]) +@pytest.mark.parametrize("model", model_configs_mla.keys()) +def test_dpa_mla(dtype, model_configs, model): + """Test DotProductAttention module with Multi-Latent Attention (MLA)""" + test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) + + model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), @@ -586,14 +641,16 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd) def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" - pad_between_seqs = False - test_dot_product_attention( - dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs - ) pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) + if get_cudnn_version() >= (9, 3, 0): + # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run + pad_between_seqs = False + test_dot_product_attention( + dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + ) def _run_dot_product_attention( @@ -736,7 +793,8 @@ def _run_dot_product_attention( "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q_after_pad[-1], "tg": cu_seqlens_kv_after_pad[-1], "3": 3, @@ -753,12 +811,16 @@ def _run_dot_product_attention( layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor_orig = tensor if qkv_format == "thd" and pad_between_seqs: tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]: + if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]: for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_q_after_pad[i - 1], @@ -772,7 +834,7 @@ def _run_dot_product_attention( tensor_orig = torch.cat( [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0 ) - if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]: + if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]: for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_kv_after_pad[i - 1], @@ -811,13 +873,14 @@ def _run_dot_product_attention( # Create output gradient qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda") out_grad_orig = out_grad if qkv_format == "thd" and pad_between_seqs: out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - if qkv_format_kv == "t_h_d": + if qkv_format_kv == "t_h_dv": for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_q_after_pad[i - 1], @@ -851,7 +914,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: # Set up model block = DotProductAttention( config.num_heads, - config.head_dim, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, @@ -906,9 +969,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + if is_training: + q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) for i in range(1, config.batch_size + 1): valid_range_q = ( cu_seqlens_q_after_pad[i - 1], @@ -919,15 +983,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: cu_seqlens_kv_after_pad[i] - pad_len[i - 1], ) out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0) - q_grad_orig = torch.cat( - [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 - ) - k_grad_orig = torch.cat( - [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 - ) - v_grad_orig = torch.cat( - [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 - ) + if is_training: + q_grad_orig = torch.cat( + [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 + ) + k_grad_orig = torch.cat( + [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) + v_grad_orig = torch.cat( + [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) if is_training: return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) else: @@ -1168,7 +1233,7 @@ def _run_transformer_layer( # Create RoPE rotary_pos_emb = None if RoPE: - PE = RotaryPositionEmbedding(dim=config.head_dim) + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") # Set up model @@ -1183,7 +1248,7 @@ def _run_transformer_layer( init_method=init_method, output_layer_init_method=output_layer_init_method, layer_number=layer_number, - kv_channels=config.head_dim, + kv_channels=config.head_dim_qk, self_attn_mask_type=config.attn_mask_type, tp_group=None, tp_size=1, @@ -1356,7 +1421,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, - kv_channels=config.head_dim, + kv_channels=config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, layer_number=1, @@ -1387,7 +1452,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "d": config.head_dim_qk, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -1531,7 +1596,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with fp8_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -1560,7 +1625,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "d": config.head_dim_qk, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -1732,7 +1797,7 @@ def _run_custom_mha_fp8(dtype, config, backend): inp = 0.0001 * torch.randint( -100, 100, - (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim), + (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk), dtype=dtype, device="cuda", requires_grad=True, @@ -1743,7 +1808,7 @@ def _run_custom_mha_fp8(dtype, config, backend): out_grad = 0.01 * torch.randn( config.batch_size * config.max_seqlen_q, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, dtype=dtype, device="cuda", ) @@ -1766,7 +1831,7 @@ def _run_custom_mha_fp8(dtype, config, backend): return ( out.view(config.batch_size, config.max_seqlen_q, -1), dqkv.view( - config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim + config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk ).contiguous(), ) @@ -1809,7 +1874,7 @@ def get_dummy_cuda_rng_tracker(): block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, attention_dropout=config.dropout_p, sequence_parallel=False, tp_size=1, @@ -2105,7 +2170,7 @@ def __init__(self, config, params_dtype: torch.dtype = torch.float32): self.p_dropout = config.dropout_p self.h = config.num_heads self.hidden_size = config.hidden_size - self.head_dim = config.head_dim + self.head_dim = config.head_dim_qk self.fast_zero_fill = True self.mask_type = config.attn_mask_type diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 31a653b505..82875e2791 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -16,11 +16,17 @@ ) model_configs_flash_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: b, h, hg, d, sq, skv, p, mask, bias "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA + "cp_1_2": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # MHA "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig( + 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA } @@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -def test_cp_with_flash_attention(dtype, model, qkv_format): +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): + config = model_configs_flash_attn[model] + if cp_comm_type == "all_gather" and qkv_format == "thd": + pytest.skip( + f"CP implementation with KV all-gather does not support {qkv_format} format yet!" + ) + if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" + " type yet!" + ) + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" + " type yet!" + ) + if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip( + f"CP implementation with KV P2P does not support window size {config.window_size} yet!" + ) + subprocess.run( get_bash_arguments( dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention" @@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format): model_configs_fused_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: b, h, hg, d, sq, skv, p, mask, bias "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA @@ -63,12 +90,48 @@ def test_cp_with_flash_attention(dtype, model, qkv_format): @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -def test_cp_with_fused_attention(dtype, model, qkv_format): +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+.") + if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): + pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0") + + config = model_configs_fused_attn[model] + if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: + pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!") + if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": + pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!") + if cp_comm_type == "all_gather" and qkv_format == "thd": + pytest.skip( + f"CP implementation with KV all-gather does not support {qkv_format} format yet!" + ) + if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" + " type yet!" + ) + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip( + f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" + " type yet!" + ) + if config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip( + "Fused attention does not support sliding window attention + context parallelism yet!" + ) + if cp_comm_type == "all_gather" and dtype == "fp8": + pytest.skip( + "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" + ) + if dtype == "fp8" and qkv_format == "thd": + pytest.skip("FP8 attention cannot work with THD format yet!") + if dtype == "fp8" and config.attn_bias_type != "no_bias": + pytest.skip("FP8 attention cannot work with bias yet!") + subprocess.run( get_bash_arguments( dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 8d3a9dca4f..60a5a1ea99 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -3,7 +3,8 @@ # See LICENSE for license information. from dataclasses import dataclass -from typing import List, Tuple +import itertools +from typing import Iterable, List, Tuple, Union import pytest import torch @@ -88,7 +89,7 @@ def generate_data( dpa: bool = False, warmup: bool = False, return_grad_output: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[List[torch.Tensor], torch.Tensor]: """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn if dpa: @@ -129,14 +130,20 @@ def generate_data( return inputs, grad_output -def get_outputs(model, output): +def get_outputs( + model: torch.nn.Module, + output: Union[torch.Tensor, Iterable[torch.Tensor]], +) -> List[torch.Tensor]: """Return grads and params for comparsion.""" values = [] for param in model.parameters(): values.append(param) if param.grad is not None: values.append(param.grad) - values.append(output) + if isinstance(output, torch.Tensor): + values.append(output) + else: + values.extend(output) return values @@ -161,7 +168,7 @@ def _test_cuda_graphs( module: str, graph_mode: str, ) -> List[torch.Tensor]: - """Helper function for test.""" + """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() dpa = module == "dpa" @@ -247,7 +254,7 @@ def _test_cuda_graphs( else: model = modules[0] if dpa else _Sequential(*modules) - # Loss function and optimizer. + # Optimizer. if not dpa: optimizer = torch.optim.SGD(model.parameters(), lr=0.001) @@ -312,3 +319,193 @@ def test_gpt_make_graphed_callables( # Check that results match assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode2) + + +def _test_cuda_graphs_with_kwargs( + *, + config: ModelConfig, + dtype: torch.dtype, + with_graph: bool, +) -> List[torch.Tensor]: + """Simulate Megatron-LM interleaved pipeline parallelism.""" + reset_rng_states() + + # Initialize model. + model = TransformerLayer( + config.hidden_size, + config.hidden_size, + config.num_heads, + hidden_dropout=0.0, + attention_dropout=0.0, + self_attn_mask_type="arbitrary", + fuse_qkv_params=True, + params_dtype=dtype, + ) + + # Initialize gradient buffers. + for param in model.parameters(): + param.grad = torch.empty_like(param) + + # Make graphed version of model if needed. + if with_graph: + attn_mask = torch.zeros( + (config.batch_size, 1, config.sequence_length, config.sequence_length), + dtype=torch.bool, + device="cuda", + ) + model = make_graphed_callables( + model, + generate_data(config, dtype, warmup=True), + sample_kwargs=dict(attention_mask=attn_mask), + allow_unused_input=True, + ) + + # Optimizer. + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + + # Training loop. + for _ in range(3): + optimizer.zero_grad(set_to_none=False) + for grad_accumulation_step in range(2): + inputs, grad_output = generate_data(config, dtype, return_grad_output=True) + attn_mask = torch.randint( + 2, + (config.batch_size, 1, config.sequence_length, config.sequence_length), + dtype=torch.bool, + device="cuda", + ) + output = model(*inputs, attention_mask=attn_mask) + output.backward(grad_output) + optimizer.step() + + return get_outputs(model, output) + + +def test_make_graphed_callables_with_kwargs( + dtype: torch.dtype = torch.float32, + model: str = "small", +) -> None: + """Test CUDA graphs with keyword arguments.""" + config = model_configs[model] + kwargs = dict(config=config, dtype=dtype) + outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs) + graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs) + assert_all_equal(outputs, graph_outputs) + + +def _test_cuda_graphs_with_interleaved_pipeline_parallelism( + *, + config: ModelConfig, + dtype: torch.dtype, + with_graph: bool, +) -> List[torch.Tensor]: + """Simulate Megatron-LM interleaved pipeline parallelism.""" + reset_rng_states() + + # Pipeline parallel configuration. + num_layers = 2 + num_microbatches = 3 + layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1] + + # Initialize model. + model = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + config.hidden_size, + params_dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + + # Initialize gradient buffers. + for param in model.parameters(): + param.grad = torch.empty_like(param) + + # Make graphed version of model if needed. + layer_forwards = { + (i % num_layers, i // num_layers): model[i % num_layers] + for i in range(num_layers * num_microbatches) + } + if with_graph: + sample_args = tuple( + generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches) + ) + layer_forwards = make_graphed_callables( + tuple(model), + sample_args, + allow_unused_input=True, + _order=layer_order, + ) + layer_forwards = { + (i // num_microbatches, i % num_microbatches): forward + for i, forward in enumerate(layer_forwards) + } + + # Optimizer. + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + + # Training loop. + for _ in range(3): + optimizer.zero_grad(set_to_none=False) + + # Generate data. + inputs = {} + grad_outputs = {} + for layer_idx in range(num_layers): + for microbatch_idx in range(num_microbatches): + x, dy = generate_data(config, dtype, return_grad_output=True) + idxs = (layer_idx, microbatch_idx) + inputs[idxs] = x[0] + grad_outputs[idxs] = dy + + # Cache for layer outputs. + outputs = {} + + def forward(layer_idx: int, microbatch_idx: int): + """Helper function for forward steps""" + idxs = (layer_idx, microbatch_idx) + outputs[idxs] = layer_forwards[idxs](inputs[idxs]) + + def backward(layer_idx: int, microbatch_idx: int): + """Helper function for backward steps""" + outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx]) + + # Forward and backward steps. + forward(0, 0) + forward(1, 0) + forward(0, 1) + forward(1, 1) + backward(1, 0) + backward(0, 0) + forward(0, 2) + forward(1, 2) + backward(1, 1) + backward(0, 1) + backward(1, 2) + backward(0, 2) + + # Optimizer step. + optimizer.step() + + outputs = [y for _, y in sorted(outputs.items())] + return get_outputs(model, outputs) + + +def test_make_graphed_callables_with_interleaved_pipeline_parallelism( + dtype: torch.dtype = torch.float16, + model: str = "small", +) -> None: + """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.""" + config = model_configs[model] + kwargs = dict(config=config, dtype=dtype) + outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + with_graph=False, + **kwargs, + ) + graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + with_graph=True, + **kwargs, + ) + assert_all_equal(outputs, graph_outputs) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 8a50648391..ee6739fbf6 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -10,6 +10,13 @@ from torch import nn from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te +from transformer_engine.pytorch.attention import MultiheadAttention +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() class TestFusedOptimizer(unittest.TestCase): @@ -169,6 +176,83 @@ def test_frozen_model(self): torch.testing.assert_close(ref_param, tst_param) + @unittest.skipIf(not is_bf16_compatible(), "bf16 if not supported") + def test_bf16_model_weight_cast(self): + dtype = torch.bfloat16 + model = MultiheadAttention( + hidden_size=1024, + num_attention_heads=16, + layer_number=1, + params_dtype=dtype, + fuse_qkv_params=True, + ).cuda() + ref_params = [] + master_params = [] + model_params = [] + for p in model.parameters(): + if p.requires_grad: + ref_params.append(p.detach().clone().float()) + master_params.append(p.detach().clone().float()) + model_params.append(p) + options = { + "lr": 5e-4, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + } + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + + for i in range(self.iters): + self.gen_grad(ref_params, master_params) + ref_optim.step() + tst_optim.step() + torch.testing.assert_close(ref_params, master_params) + model_params_to_fp32 = [p.float() for p in model_params] + torch.testing.assert_close( + ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True + ) + + @unittest.skipIf(not fp8_available, reason=reason_for_no_fp8) + def test_fp8_model_weight_cast(self): + dtype = torch.bfloat16 + with fp8_model_init(enabled=True): + model = MultiheadAttention( + hidden_size=1024, + num_attention_heads=16, + layer_number=1, + params_dtype=dtype, + fuse_qkv_params=True, + ).cuda() + ref_params = [] + master_params = [] + model_params = [] + for p in model.parameters(): + if p.requires_grad: + ref_params.append(p.detach().clone().float()) + master_params.append(p.detach().clone().float()) + model_params.append(p) + options = { + "lr": 5e-4, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + } + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + + for i in range(self.iters): + self.gen_grad(ref_params, master_params) + ref_optim.step() + tst_optim.step() + torch.testing.assert_close(ref_params, master_params) + model_params_to_fp32 = [p.float() for p in model_params] + torch.testing.assert_close( + ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True + ) + class TestFusedSGD(TestFusedOptimizer): def __init__(self, *args, **kwargs): @@ -345,8 +429,9 @@ def testGradScalerCapturableMaster(self): if m.__class__ in [torch.nn.Conv2d]: m.half() params_ = [p for p in self.model_.parameters() if p.requires_grad] + master_weights = [p.float() for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam( - params_, lr=self.lr, capturable=True, master_weights=True + params_, lr=self.lr, capturable=True, master_weights=master_weights ) scaler = torch.cuda.amp.GradScaler(enabled=True) scaler_ = torch.cuda.amp.GradScaler(enabled=True) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9aab3b2702..3523e1cda5 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -15,8 +15,10 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor -from transformer_engine.pytorch.ops.fused_forward import ( +from transformer_engine.pytorch.ops.fused import ( + BackwardLinearAdd, ForwardLinearBiasActivation, + ForwardLinearBiasAdd, ) from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex @@ -84,15 +86,14 @@ def make_reference_and_test_tensors( """ ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(ref) + test = Float8Tensor.to_float8(test) test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1) test._transpose = test._transpose.contiguous() test._transpose_invalid = False - else: - test = ref.to(device=test_device, dtype=test_dtype) - if test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif test.data_ptr() == ref.data_ptr(): + test = test.clone() ref.copy_(test) ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) @@ -320,14 +321,13 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("in_shape", ((1,),)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("fp8", (False, True)) def test_identity( self, *, - in_shape: Iterable[int], + in_shape: Iterable[int] = (1,), dtype: torch.dtype, device: torch.device, fp8: bool, @@ -737,6 +737,123 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", ("cuda", "cpu")) + @pytest.mark.parametrize("fp8", (False, True)) + def test_add_in_place( + self, + *, + in_shape: Iterable[int] = (1,), + dtype: torch.dtype, + device: torch.device, + fp8: bool, + ) -> None: + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = x2_ref.detach() + y_ref += x1_ref + dx1_ref = dy_ref + dx2_ref = dy_ref + + # Implementation with fusible operation + op = te_ops.AddInPlace() + y_test = op(x1_test, x2_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(x1_test._fp8_dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0) + torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", ("cuda", "cpu")) + @pytest.mark.parametrize("fp8", (False, True)) + def test_make_extra_output( + self, + *, + in_shape: Iterable[int] = (1,), + dtype: torch.dtype, + device: torch.device, + fp8: bool, + ) -> None: + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8, + ) + dy1_ref, dy1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + dy2_ref, dy2_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y1_ref = x_ref + y2_ref = x_ref + (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() + + # Implementation with fusible operation + op = te_ops.MakeExtraOutput() + y1_test, y2_test = op(x_test) + (y1_test * dy1_test + y2_test * dy2_test).sum().backward() + + # Check results + tols = dtype_tols(dtype) + y1_test = y1_test.to(dtype=torch.float64, device="cpu") + y2_test = y2_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0) + torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" @@ -754,7 +871,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("fp8_compute", (False, True)) @pytest.mark.parametrize("fp8_input", (False, True)) @pytest.mark.parametrize("fp8_weight", (False, True)) - def test_linear_bias_activation( + def test_forward_linear_bias_activation( self, *, bias: bool = True, @@ -766,7 +883,7 @@ def test_linear_bias_activation( fp8_input: bool, fp8_weight: bool, ) -> None: - """GEMM + bias + activation""" + """Forward GEMM + bias + activation""" # Make input and weight shapes consistent out_features, in_features = weight_shape @@ -951,3 +1068,247 @@ def test_fp8_linear( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw0_test, w0_ref.grad, **tols) torch.testing.assert_close(dw1_test, w1_ref.grad, **tols) + + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + def test_forward_linear_bias_add( + self, + *, + bias: bool, + weight_shape: tuple[int, int] = (16, 16), + in_shape: Iterable[int] = (16, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_output: bool = False, + ) -> None: + """Forward GEMM + bias + add""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_input or fp8_weight or fp8_output or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if fp8_output and not fp8_compute: + pytest.skip("FP8 output requires FP8 compute") + if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_output, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operations + with te.fp8_model_init(enabled=fp8_weight): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ), + te_ops.AddInPlace(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + if bias: + model[0].bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x1_test, x2_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) + torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + if bias: + db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + def test_backward_linear_add( + self, + *, + weight_shape: tuple[int, int] = (16, 16), + in_shape: Iterable[int] = (16, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_compute: bool, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_output: bool = False, + ) -> None: + """Backward dgrad GEMM + add""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + if fp8_input or fp8_weight or fp8_output or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + if fp8_compute: + if ( + math.prod(in_shape[:-1]) % 16 != 0 + or in_features % 16 != 0 + or out_features % 16 != 0 + ): + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if fp8_output and not fp8_compute: + pytest.skip("FP8 output requires FP8 compute") + if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_input), + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=(fp8_compute or fp8_weight), + ) + dy1_ref, dy1_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + dy2_ref, dy2_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y1_ref = torch.nn.functional.linear(x_ref, w_ref) + y2_ref = x_ref + (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() + + # Implementation with fusible operations + with te.fp8_model_init(enabled=fp8_weight): + model = te_ops.Sequential( + te_ops.MakeExtraOutput(), + te_ops.Linear( + in_features, + out_features, + bias=False, + device=device, + dtype=dtype, + ), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=fp8_compute): + y1_test, y2_test = model(x_test) + (y1_test * dy1_test + y2_test * dy2_test).sum().backward() + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], BackwardLinearAdd) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[1].weight._fp8_dtype + if is_float8_tensor(model[1].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y1_test = y1_test.to(dtype=torch.float64, device="cpu") + y2_test = y2_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y1_test, y1_ref, **tols) + torch.testing.assert_close(y2_test, y2_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 7eed97a0ca..85cd4fc256 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -34,11 +34,13 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace +from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +sm_80plus = get_device_compute_capability() >= (8, 0) seed = 1234 torch.manual_seed(seed) @@ -1228,7 +1230,8 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False inp_hidden_states.retain_grad() m = config.seq_len // 16 - dist = torch.sort(torch.randint(0, m, (num_gemms - 1,))).values.tolist() + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) m_splits = m_splits * 16 assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms @@ -1547,8 +1550,29 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): attn_input_format="bshd", ) - for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): - assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" + torch.manual_seed(0) + block_thd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="thd", + self_attn_mask_type="padding_causal", + ) + + for (n1, p1), (n2, p2), (n3, p3) in zip( + block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters() + ): + assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( (config.seq_len, bs, config.hidden_size), @@ -1558,6 +1582,8 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ) x_bshd = x_sbhd.transpose(0, 1).contiguous() + x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -1575,6 +1601,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): y_sbhd.transpose(0, 1).contiguous(), ) + # THD is not supported in float32 and on GPUs older than Ampere, skip the test here + if dtype != torch.float32 and sm_80plus: + # To make sure forward is also identical (just in case some module decides + # to act fancy) + torch.manual_seed(0) + y_thd = block_thd( + x_thd, + cu_seqlens_q=x_thd_cumsum, + cu_seqlens_kv=x_thd_cumsum, + max_seqlen_q=config.seq_len, + max_seqlen_kv=config.seq_len, + ) + + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + ) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @@ -1611,8 +1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ffn_hidden_size=4 * D, num_attention_heads=H, attn_input_format=input_format, - self_attn_mask_type="causal_bottom_right", - enc_dec_attn_mask_type="causal_bottom_right", + self_attn_mask_type="causal", + enc_dec_attn_mask_type="causal", layer_number=layer_number, attention_dropout=0.0, params_dtype=dtype, @@ -1626,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, qkv_format=input_format, layer_number=layer_number, attention_dropout=0.0, - attn_mask_type="causal_bottom_right", + attn_mask_type="causal", params_dtype=dtype, ) .cuda() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b814ef5974..58bd4f828c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -4,23 +4,31 @@ cmake_minimum_required(VERSION 3.21) +# Language options if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) endif() - set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -project(transformer_engine LANGUAGES CUDA CXX) - if (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") endif() +# Hide non-necessary symbols in shared object. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") + +# Transformer Engine library +project(transformer_engine LANGUAGES CUDA CXX) + +# CUDA Toolkit find_package(CUDAToolkit REQUIRED) +if (CUDAToolkit_VERSION VERSION_LESS 12.0) + message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") +endif() -# Check for cuDNN frontend API +# cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") @@ -31,10 +39,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") endif() include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +# Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) -include_directories(${PROJECT_SOURCE_DIR}/..) # Configure Transformer Engine library +include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES pycudnn.cpp @@ -73,8 +82,6 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) - # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas @@ -84,7 +91,10 @@ target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") -# Make header files with C++ strings +# Hack to enable dynamic loading in cuDNN frontend +target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) + +# Helper functions to make header files with C++ strings function(make_string_header STRING STRING_NAME) configure_file(util/string_header.h.in "string_headers/${STRING_NAME}.h" @@ -96,10 +106,11 @@ function(make_string_header_from_file file_ STRING_NAME) "string_headers/${STRING_NAME}.h" @ONLY) endfunction() + +# Header files with C++ strings list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path) make_string_header("${cuda_include_path}" string_path_cuda_include) - make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu string_code_transpose_rtc_cast_transpose_fusion_cu) make_string_header_from_file(transpose/rtc/cast_transpose.cu @@ -110,7 +121,6 @@ make_string_header_from_file(utils.cuh string_code_utils_cuh) make_string_header_from_file(util/math.h string_code_util_math_h) - target_include_directories(transformer_engine PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/string_headers") @@ -123,6 +133,23 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") +# Number of parallel build jobs +if(ENV{MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") +elseif(ENV{NVTE_BUILD_MAX_JOBS}) + set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") +else() + set(BUILD_JOBS_STR "max") +endif() +message(STATUS "Parallel build jobs: ${BUILD_JOBS_STR}") + +# Number of threads per parallel build job +set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB}) +if (NOT BUILD_THREADS_PER_JOB) + set(BUILD_THREADS_PER_JOB 1) +endif() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") +message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") + # Install library install(TARGETS transformer_engine DESTINATION .) - diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index f4eb2c419f..46cfa9176a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -4,6 +4,7 @@ """FW agnostic user-end APIs""" +import sys import glob import sysconfig import subprocess @@ -15,6 +16,16 @@ import transformer_engine +def is_package_installed(package): + """Checks if a pip package is installed.""" + return ( + subprocess.run( + [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False + ).returncode + == 0 + ) + + def get_te_path(): """Find Transformer Engine install path using pip""" return Path(transformer_engine.__path__[0]).parent diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 895baea789..70f1fa409f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -72,8 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, - int64_t window_size_right) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -84,10 +84,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && - (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || + (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && + (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && - (max_seqlen_kv % 128 == 0) && (head_dim == 128) && + (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -104,8 +104,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) && - (num_attn_heads == num_gqa_groups) && + (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && + (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -131,18 +131,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) && // head dimension - ((head_dim <= 128 && head_dim % 8 == 0) || + ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || // TODO (cyang): add is_training to nvte_get_fused_attn_backend // d=256 only supported for forward - (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 && - head_dim % 8 == 0)) && + (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && + head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version >= 8906) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) || + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || ((cudnn_runtime_version >= 90000) && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && @@ -155,6 +158,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || ((cudnn_runtime_version >= 90300) && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && @@ -259,7 +263,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -336,7 +340,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -430,7 +434,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -514,7 +518,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -595,7 +599,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; size_t h_kv = input_K->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -603,13 +608,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, + input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); @@ -617,18 +622,18 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -674,7 +679,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; size_t h_kv = input_K->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -682,15 +688,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S, - output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, - input_cu_seqlens_kv, wkspace, stream, handle); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -705,9 +711,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -721,7 +727,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 525fd3330d..42fb779717 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -48,11 +48,11 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, - int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, + float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( hg, s_q, s_kv, - d, + d_qk, + d_v, bias_b, bias_h, scaling_factor, @@ -167,41 +168,41 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); if (is_ragged) { Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v)); } else { Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride)); } @@ -265,15 +266,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { O->set_output(true) - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o); } else { - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); } Stats->set_output(true) @@ -360,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d, static_cast(devPtrSeqOffsetsQ), + layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), static_cast(devOffsetsK), static_cast(devOffsetsV), static_cast(devOffsetsO)); @@ -381,13 +382,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } void fused_attn_arbitrary_seqlen_bwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, - int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -419,7 +420,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( hg, s_q, s_kv, - d, + d_qk, + d_v, bias_b, bias_h, scaling_factor, @@ -505,61 +507,61 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o)); } else { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); } stats = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -586,7 +588,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_sliding_window_length(window_size_left); } - if (cudnn_runtime_version >= 90000 && sm_arch_ >= 90) { + if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); } @@ -644,21 +646,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged) { dQ->set_output(true) - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q); dK->set_output(true) - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k); dV->set_output(true) - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v); } else { - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); } std::tuple, // q @@ -758,7 +760,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d, static_cast(devPtrSeqOffsetsQ), + layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), static_cast(devOffsetsK), static_cast(devOffsetsV), static_cast(devOffsetsO)); @@ -865,11 +867,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, + bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, + devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, + handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -941,11 +944,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, + bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1051,12 +1054,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, + bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1131,12 +1134,13 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, + bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1155,8 +1159,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1233,12 +1237,12 @@ void fused_attn_arbitrary_seqlen_fwd( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1257,7 +1261,7 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, @@ -1302,12 +1306,13 @@ void fused_attn_arbitrary_seqlen_bwd( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 27a2dd37ea..4b523cca1a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -58,8 +58,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -68,7 +68,7 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fcce30d6a1..fb7765e1a4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1679,6 +1679,7 @@ void fused_attn_fp8_fwd_impl_v1( s_q, s_kv, d, + d, bias_b, bias_h, scaling_factor, @@ -1834,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); - amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_o->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) @@ -1976,6 +1983,7 @@ void fused_attn_fp8_bwd_impl_v1( s_q, s_kv, d, + d, bias_b, bias_h, scaling_factor, @@ -2180,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1( dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); - amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dQ->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); dO->set_data_type(bwd_tensor_type); dQ->set_data_type(bwd_tensor_type); diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 7467462d2a..56dbb278b4 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -363,29 +363,30 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu // convert cu_seqlens_padded to offsets __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d, int32_t *cu_seqlens_q_padded, + size_t hg, size_t d_qk, size_t d_v, + int32_t *cu_seqlens_q_padded, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_o) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < b + 1) { - offsets_o[tid] = h * d * cu_seqlens_q_padded[tid]; + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid]; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; - offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = offsets_k[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid]; + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid]; offsets_k[tid] = offsets_q[tid]; offsets_v[tid] = offsets_q[tid]; break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; - offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid]; offsets_v[tid] = offsets_k[tid]; break; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 74d1628a33..d5cf450181 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -91,7 +91,8 @@ struct FADescriptor_v1 { std::int64_t hg; std::int64_t s_q; std::int64_t s_kv; - std::int64_t d; + std::int64_t d_qk; + std::int64_t d_v; std::int64_t bias_b; std::int64_t bias_h; float attnScale; @@ -107,11 +108,11 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining, + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, + rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); } @@ -126,7 +127,8 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu int32_t *kv_seqlens); __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d, int32_t *cu_seqlens_q_padded, + size_t hg, size_t d_qk, size_t d_v, + int32_t *cu_seqlens_q_padded, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_o); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 342c53bc7f..ae08f2a4aa 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -22,7 +22,7 @@ extern "C" { /*! \enum NVTE_QKV_Layout * \brief Memory layouts of QKV tensors. * `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, number of heads, - * head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. + * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length * or padded to the same length, and `THD`-based layouts are used when sequences have * different lengths in a batch. @@ -147,15 +147,16 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); * \param[in] num_gqa_groups The number of heads in K, V. * \param[in] max_seqlen_q The sequence length of Q. * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim The head dimension of Q, K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, - int64_t window_size_right); + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. * diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version new file mode 100644 index 0000000000..0683ec01ea --- /dev/null +++ b/transformer_engine/common/libtransformer_engine.version @@ -0,0 +1,4 @@ +{ + global: *nvte*; *transformer_engine*; + local: *; +}; diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 3200c8a019..05adbd624c 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -5,21 +5,50 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import ctypes +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_jax" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[jax]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[jax]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f9b5156847..56359646b1 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -4,12 +4,14 @@ """JAX/TE custom ops for activation""" from typing import Tuple, Sequence, Union, Callable import operator -from functools import reduce +from functools import reduce, partial +import jax import jax.numpy as jnp from jax import core, dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import NVTE_Activation_Type @@ -21,7 +23,9 @@ jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, get_padded_spec, + is_ffi_enabled, ) +from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP @@ -42,6 +46,35 @@ } +def _convert_to_activation_function(fn_or_string): + """Convert a string to an activation function.""" + if fn_or_string == "linear": + return lambda x: x + if fn_or_string == "quick_gelu": + return lambda x: jax.nn.sigmoid(1.702 * x) * x + if fn_or_string == "squared_relu": + return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if isinstance(fn_or_string, str): + return getattr(jax.nn, fn_or_string) + if callable(fn_or_string): + return fn_or_string + raise ValueError(f"Unsupported {fn_or_string} to an activation function") + + +def _jax_act_lu(inputs, activation_type): + """ + JAX native activation implementation + """ + x = jnp.split(inputs, len(activation_type), axis=-2) + acts = [] + for idx, act_fn in enumerate(activation_type): + x_i = _convert_to_activation_function(act_fn)(x[idx]) + acts.append(x_i) + x = reduce(operator.mul, acts) + x = jnp.squeeze(x, axis=-2) + return x + + class ActLuPrimitive(BasePrimitive): """ Activation Forward Primitive @@ -78,25 +111,29 @@ def lowering(ctx, x, *, act_enum): """ (x_aval,) = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_x_type.element_type), - ] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-2]) - in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor( - (batch_size, hidden_size), in_dtype, in_dtype, act_enum - ) + if is_ffi_enabled(): + name = "te_act_lu_ffi" + out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] + + out_types = [ + ir.RankedTensorType.get(out_shape, ir_x_type.element_type), + ] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + hidden_size = ir_x_shape[-1] + batch_size = reduce(operator.mul, ir_x_shape[:-2]) + in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor( + (batch_size, hidden_size), in_dtype, in_dtype, act_enum + ) - out = custom_caller(ActLuPrimitive.name, args, opaque, False) + out = custom_caller(ActLuPrimitive.name, args, opaque, False) return out @@ -155,7 +192,10 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) Input shape: (N, 1, H) for non-gated activations (N, 2, H) for gated activations """ - act_type_id = ActivationEnum[activation_type] + if not ActLuPrimitive.enabled(): + return _jax_act_lu(inputs, activation_type) + + act_type_id = ActivationEnum[activation_type].value return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id) @@ -197,34 +237,38 @@ def lowering(ctx, dz, x, *, act_enum): in_aval, gi_aval = ctx.avals_in assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert gi_aval.dtype == in_aval.dtype - ir_in_type = ir.RankedTensorType(dz.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(x.type) - gi_shape = gi_type.shape - # assert ir_in_shape == gi_shape - for axis in range(len(ir_in_shape) - 1): - assert ir_in_shape[axis] == gi_shape[axis] - - ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) - i_hidden_size = ir_in_shape[-1] - g_hidden_size = gi_shape[-1] - assert i_hidden_size == g_hidden_size - out_dtype = ir_in_type.element_type - out_shape = gi_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, out_dtype), - ] - operands = [dz, x] - operand_shapes = [ir_in_shape, gi_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor( - (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum - ) + if is_ffi_enabled(): + name = "te_dact_lu_ffi" + out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum) + else: + ir_in_type = ir.RankedTensorType(dz.type) + ir_in_shape = ir_in_type.shape + gi_type = ir.RankedTensorType(x.type) + gi_shape = gi_type.shape + # assert ir_in_shape == gi_shape + for axis in range(len(ir_in_shape) - 1): + assert ir_in_shape[axis] == gi_shape[axis] + + ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) + i_hidden_size = ir_in_shape[-1] + g_hidden_size = gi_shape[-1] + assert i_hidden_size == g_hidden_size + out_dtype = ir_in_type.element_type + out_shape = gi_shape + + out_types = [ + ir.RankedTensorType.get(out_shape, out_dtype), + ] + operands = [dz, x] + operand_shapes = [ir_in_shape, gi_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor( + (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum + ) - out = custom_caller(DActLuPrimitive.name, args, opaque, False) + out = custom_caller(DActLuPrimitive.name, args, opaque, False) return out @@ -286,7 +330,11 @@ def dact_lu( dact_lu fusion wrapper Return dgated_act_lu(inputs) """ - act_type_id = ActivationEnum[activation_type] + if not DActLuPrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs) + return vjp_func(inputs)[0] + + act_type_id = ActivationEnum[activation_type].value return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) @@ -443,7 +491,12 @@ def act_lu_fp8( Input shape: (N, 1, H) for non-gated activations (N, 2, H) for gated activations """ - act_type_id = ActivationEnum[activation_type] + if not ActLuFp8Primitive.enabled(): + act_lu_output = _jax_act_lu(x, activation_type) + casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype) + return casted_output, updated_amax + + act_type_id = ActivationEnum[activation_type].value return ActLuFp8Primitive.outer_primitive.bind( x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id ) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4e94de08c4..76ccec363b 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,8 +3,9 @@ # See LICENSE for license information. """JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce +from functools import partial, reduce, cache import operator +import os from typing import Optional, Tuple import warnings @@ -30,6 +31,7 @@ jax_dtype_to_te_dtype, te_dtype_to_jax_dtype, get_padded_spec, + get_cudnn_version, ) from ..sharding import ( all_reduce_sum_along_dp_fsdp, @@ -83,6 +85,12 @@ def get_fused_attn_backend(self): self.head_dim, ) + @staticmethod + @cache + def is_non_deterministic_allowed(): + """Check if non-deterministic kernels are allowed""" + return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + @staticmethod def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): """Parse qkv aval""" @@ -364,6 +372,7 @@ def lowering( jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), ) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) @@ -393,12 +402,12 @@ def impl( if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: - def _fix_len_take(x, condition): + def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape x = x.flatten() size = x.size indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] - y = jnp.take(x, indices, fill_value=-1) + y = jnp.take(x, indices, fill_value=fill_value) return jnp.reshape(y, x_shape) def convert_to_2d(offsets, batch, max_seqlen): @@ -425,9 +434,16 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_batch = reduce(operator.mul, k.shape[:-3]) # Gather valid q_seqlen, which is greater than 0 + # cuDNN version < 9.3.0: # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] - q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0) - kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0) + # cuDNN version >= 9.3.0, which supports act_seqlen = 0 + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) # Flatten the offset calculation # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] @@ -634,6 +650,8 @@ def abstract( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + deterministic = not FusedAttnHelper.is_non_deterministic_allowed() + input_batch = reduce(operator.mul, batch_shape) wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( input_batch, @@ -651,6 +669,7 @@ def abstract( qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training, + deterministic, max_segments_per_seq, ) @@ -756,6 +775,7 @@ def lowering( jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), ) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) @@ -788,13 +808,13 @@ def impl( if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: - def _fix_len_take(x, condition): + def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape x = x.flatten() size = x.size indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] # TODO(rewang): try indices_are_sorted - y = jnp.take(x, indices, fill_value=-1) + y = jnp.take(x, indices, fill_value=fill_value) return jnp.reshape(y, x_shape) def convert_to_2d(offsets, batch, max_seqlen): @@ -821,9 +841,16 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_batch = reduce(operator.mul, k.shape[:-3]) # Gather valid q_seqlen, which is greater than 0 + # cuDNN version < 9.3.0: # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] - q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0) - kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0) + # cuDNN version >= 9.3.0, which supports act_seqlen = 0 + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) # Flatten the offset calculation # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 88fab695d6..3d88c1f078 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. """JAX/TE base custom ops""" +import os +import re from abc import ABCMeta, abstractmethod from functools import partial @@ -17,6 +19,21 @@ class BasePrimitive(metaclass=ABCMeta): jax primitive """ + name = None + + @classmethod + def enabled(cls): + """ + A custom call is marked as disabled if the `cls.name` does not fully match the + `NVTE_JAX_CUSTOM_CALLS_RE` pattern. + By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. + For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`. + """ + pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") + pattern = re.compile(pattern) + is_enabled = pattern.fullmatch(cls.name) is not None + return is_enabled + @staticmethod @abstractmethod def abstract(): diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 36396a977c..8e58ed3bed 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -3,12 +3,14 @@ # See LICENSE for license information. """JAX/TE custom call""" from dataclasses import dataclass +from enum import IntEnum from jax.lib import xla_client from jax.interpreters import mlir from transformer_engine import transformer_engine_jax +from .misc import is_ffi_enabled try: from jaxlib.hlo_helpers import custom_call @@ -17,8 +19,25 @@ # version, so we still need this import. pass + +class CustomCallAPIVersion(IntEnum): + """Enum for selecting between old and new custom call registration API""" + + OPAQUE = 0 + FFI = 1 + + for _name, _value in transformer_engine_jax.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") + if _name.endswith("_ffi"): + if is_ffi_enabled(): + # COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default + xla_client.register_custom_call_target( + _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value + ) + else: + xla_client.register_custom_call_target( + _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value + ) @dataclass @@ -79,7 +98,7 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): result_layouts=args.output_layouts, backend_config=opaque, has_side_effect=has_side_effect, - **kwargs + **kwargs, ).results else: # Need to disable one pylint error as the second function @@ -93,6 +112,6 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): result_layouts=args.output_layouts, backend_config=opaque, has_side_effect=has_side_effect, - **kwargs + **kwargs, ) return out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index b27e97d7b5..58b8db4c88 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -3,12 +3,20 @@ # See LICENSE for license information. """JAX/TE miscellaneous for custom ops""" +import os +import functools +from typing import Tuple +from importlib.metadata import version as get_pkg_version +from packaging.version import Version as PkgVersion + import numpy as np + import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import dtype_to_ir_type from transformer_engine.transformer_engine_jax import DType as TEDType +from transformer_engine import transformer_engine_jax from ..sharding import get_padded_spec as te_get_padded_spec @@ -128,3 +136,34 @@ def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): *shape[transpose_axis_boundary:], *shape[transpose_start_idx:transpose_axis_boundary], ) + + +@functools.lru_cache(maxsize=None) +def get_cudnn_version() -> Tuple[int, int, int]: + """Runtime cuDNN version (major, minor, patch)""" + encoded_version = transformer_engine_jax.get_cudnn_version() + major_version_magnitude = 1000 if encoded_version < 90000 else 10000 + major, encoded_version = divmod(encoded_version, major_version_magnitude) + minor, patch = divmod(encoded_version, 100) + return (major, minor, patch) + + +@functools.lru_cache(maxsize=None) +def jax_version_meet_requirement(version: str): + """ + Helper function checking if required JAX version is available + """ + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + +def is_ffi_enabled(): + """ + Helper function checking if XLA Custom Call with FFI is enabled + """ + is_supported = jax_version_meet_requirement("0.4.31") + # New APIs with FFI are enabled by default + is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1")) + assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value" + return is_supported and is_enabled diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 59468db0da..caf9272b02 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -2,11 +2,12 @@ # # See LICENSE for license information. """JAX/TE custom ops for normalization""" -from functools import partial, reduce +from functools import partial, reduce, cache import operator import os import warnings +import jax import jax.numpy as jnp from jax import core, dtypes from jax.interpreters import mlir @@ -25,6 +26,7 @@ jax_dtype_to_ir_dtype, te_dtype_to_jax_dtype, ) +from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp @@ -38,6 +40,18 @@ ] +@cache +def get_forward_sm_margin(): + """Retrieves the number of stream multiprocessors (SM) reserved for other kernels""" + return int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + + +@cache +def get_backward_sm_margin(): + """Retrieves the number of stream multiprocessors (SM) reserved for other kernels""" + return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + + class LayerNormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward Primitive @@ -75,6 +89,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): True, kwargs["zero_centered_gamma"], kwargs["epsilon"], + get_forward_sm_margin(), ) wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -134,7 +149,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): operand_shapes = [x_shape, g_shape, b_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -239,12 +254,77 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): register_primitive(LayerNormFwdPrimitive) +def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps): + """ + JAX native layernorm implementation + """ + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1.0 + return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) + + +def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps): + """ + JAX native rmsnorm implementation + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + normed_input = x_ * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1.0 + return jnp.asarray(normed_input * gamma).astype(x.dtype) + + +def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps): + """ + JAX native layernorm fp8 implementation + """ + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(var + eps) + normed_input = (x_ - mean) * rsigma + if zero_centered_gamma: + gamma += 1.0 + output = normed_input * gamma + beta + casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype) + return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax + + +def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): + """ + JAX native rmsnorm fp8 implementation + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(var + eps) + normed_input = x_ * rsigma + if zero_centered_gamma: + gamma += 1.0 + output = normed_input * gamma + casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype) + return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax + + def layernorm_fwd( x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float ): """ Wrapper for TE layernorm fwd """ + if not LayerNormFwdPrimitive.enabled(): + x_ = jnp.asarray(x, jnp.float32) + mu = jnp.mean(x_, axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon) + return ( + _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon), + jnp.squeeze(mu, axis=-1), + jnp.squeeze(rsigma, axis=-1), + ) return LayerNormFwdPrimitive.outer_primitive.bind( x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) @@ -287,6 +367,7 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): True, kwargs["zero_centered_gamma"], kwargs["epsilon"], + get_backward_sm_margin(), ) ) wkspace_aval = dx_aval.update( @@ -353,7 +434,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_backward_sm_margin() wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] opaque = transformer_engine_jax.pack_norm_descriptor( @@ -468,12 +549,21 @@ def layernorm_bwd( mu: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, + beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float, ): """ Wrapper for TE layernorm bwd """ + if not LayerNormBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon), + x, + gamma, + beta, + ) + return vjp_func(dz) return LayerNormBwdPrimitive.outer_primitive.bind( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) @@ -515,6 +605,7 @@ def abstract(x_aval, gamma_aval, **kwargs): False, False, kwargs["epsilon"], + get_forward_sm_margin(), ) wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -562,7 +653,7 @@ def lowering(ctx, x, gamma, *, epsilon): operand_shapes = [x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -655,6 +746,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): """ Wrapper for TE rmsnorm fwd """ + if not RmsNormFwdPrimitive.enabled(): + x_ = jnp.asarray(x, jnp.float32) + rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon) + return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze( + rsigma, axis=-1 + ) return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon) @@ -694,6 +791,7 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): False, False, kwargs["epsilon"], + get_backward_sm_margin(), ) ) wkspace_aval = dx_aval.update( @@ -747,7 +845,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_backward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -852,6 +950,11 @@ def rmsnorm_bwd( """ Wrapper for TE layernorm bwd """ + if not RmsNormBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma + ) + return vjp_func(dz) return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) @@ -902,6 +1005,7 @@ def abstract( True, zero_centered_gamma, epsilon, + get_forward_sm_margin(), ) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) @@ -989,7 +1093,7 @@ def lowering( ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -1148,6 +1252,17 @@ def layernorm_fwd_fp8( """ Wrapper for TE layernorm fwd (fp8 out) """ + if not LayerNormFwdFp8Primitive.enabled(): + return _jax_layernorm_fp8( + x, + gamma, + beta, + scale, + amax, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + ) return LayerNormFwdFp8Primitive.outer_primitive.bind( x, gamma, @@ -1198,6 +1313,7 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp False, False, epsilon, + get_forward_sm_margin(), ) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) @@ -1267,7 +1383,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + sm_margin = get_forward_sm_margin() opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, @@ -1387,6 +1503,10 @@ def rmsnorm_fwd_fp8( """ Wrapper for TE rmsnorm fwd (fp8 out) """ + if not RmsNormFwdFp8Primitive.enabled(): + return _jax_rmsnorm_fp8( + x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon + ) return RmsNormFwdFp8Primitive.outer_primitive.bind( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 40974b07b9..2c529e71c8 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -4,6 +4,7 @@ """JAX/TE custom ops for quantization""" from typing import Tuple +import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir @@ -26,6 +27,26 @@ __all__ = ["cast_fp8"] +def _jax_quantize(x, scale, q_dtype): + """ + Quantize with scale + """ + compute_dtype = scale.dtype + dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype) + scaled_x = x.astype(compute_dtype) * scale + clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max) + return clipped_scaled_x.astype(q_dtype) + + +def _jax_cast_fp8(inputs, scale, amax, out_dtype): + """ + JAX native fp8 casting implementation + """ + casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype)) + return casted_output, updated_amax + + class CastFP8Primitive(BasePrimitive): """ Cast Primitive @@ -157,4 +178,6 @@ def cast_fp8( Cast wrapper Return FP8 tensor """ + if not CastFP8Primitive.enabled(): + return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype) return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index c2dfb65e41..bf92c00de3 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -7,6 +7,7 @@ import operator import warnings +import jax import jax.numpy as jnp from jax import core, dtypes from jax.interpreters.mlir import ir @@ -31,6 +32,30 @@ ] +def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): + return jax.nn.softmax(scale_factor * logits) + + +def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): + if mask is not None: + logits += jax.lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.0).astype(logits.dtype), + ) + return jax.nn.softmax(logits * scale_factor) + + +def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): + mask = 1 - jnp.tril(jnp.ones_like(logits)) + logits += jax.lax.select( + mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.0).astype(logits.dtype), + ) + return jax.nn.softmax(logits * scale_factor) + + def is_softmax_kernel_available( softmax_type: SoftmaxType, batch: int, @@ -395,6 +420,8 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: scaled_softmax_forward wrapper Return FP16/BF16 tensor """ + if not ScaledSoftmaxFwdPrimitive.enabled(): + return _jax_scaled_softmax(logits, scale_factor) return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) @@ -469,12 +496,16 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_softmax_bwd( - dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float + dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float ) -> jnp.ndarray: """ scaled_backward wrapper Return FP16/BF16 tensor """ + if not ScaledSoftmaxBwdPrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits) + return vjp_func(dz)[0] + return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor ) @@ -625,6 +656,8 @@ def scaled_masked_softmax_fwd( scaled_masked_softmax_forward wrapper Return FP16/BF16 tensor """ + if not ScaledMaskedSoftmaxFwdPrimitive.enabled(): + return _jax_scaled_masked_softmax(logits, mask, scale_factor) return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, mask, scale_factor=scale_factor ) @@ -704,12 +737,21 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_masked_softmax_bwd( - dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float + dz: jnp.ndarray, + softmax_out: jnp.ndarray, + logits: jnp.ndarray, + mask: jnp.ndarray, + scale_factor: float, ) -> jnp.ndarray: """ scaled_masked_backward wrapper Return FP16/BF16 tensor """ + if not ScaledMaskedSoftmaxBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask + ) + return vjp_func(dz)[0] return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor ) @@ -806,6 +848,8 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl scaled_upper_triang_masked_softmax_forward wrapper Return FP16/BF16 tensor """ + if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled(): + return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor) return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, scale_factor=scale_factor ) @@ -893,12 +937,17 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_upper_triang_masked_softmax_bwd( - dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float + dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float ) -> jnp.ndarray: """ scaled_upper_triang_masked_backward wrapper Return FP16/BF16 tensor """ + if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled(): + _, vjp_func = jax.vjp( + partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits + ) + return vjp_func(dz)[0] return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor ) diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 9102b55cae..e503792dc0 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -6,10 +6,12 @@ from typing import Tuple, Sequence, Union, Callable import operator +import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType @@ -24,8 +26,11 @@ get_padded_spec, multidim_transpose, normalize_axis_boundary, + is_ffi_enabled, ) from .activation import ActivationEnum +from .activation import _jax_act_lu +from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp @@ -38,6 +43,27 @@ ] +def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary): + """ + JAX native transpose implementation + """ + axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary) + return jnp.transpose(inputs, axes=axes) + + +def _jax_cast_transpose( + inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary +): + """ + JAX native cast_transpose implementation + """ + casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype) + casted_transposed_output = _jax_transpose( + casted_output, static_axis_boundary, transpose_axis_boundary + ) + return casted_output, casted_transposed_output, updated_amax + + class TransposePrimitive(BasePrimitive): """ Transpose Primitive @@ -176,6 +202,8 @@ def transpose( """ transpose wrapper """ + if not TransposePrimitive.enabled(): + return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary) return TransposePrimitive.outer_primitive.bind( x, static_axis_boundary=static_axis_boundary, @@ -236,45 +264,49 @@ def lowering( assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - if static_axis_boundary >= 0: - for i in range(static_axis_boundary + 1): - assert ir_x_shape[i] == 1 - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - transposed_x_shape = multidim_transpose( - ir_x_shape, static_axis_boundary, transpose_axis_boundary - ) - - out_types = [ - ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - contracted_x_shape = ( - reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), - reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), - ) - opaque = transformer_engine_jax.pack_common_descriptor( - contracted_x_shape, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - ) - - out = custom_caller( - CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2} - ) + if is_ffi_enabled(): + name = "te_cast_transpose_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})( + ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary + ) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + if static_axis_boundary >= 0: + for i in range(static_axis_boundary + 1): + assert ir_x_shape[i] == 1 + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + transposed_x_shape = multidim_transpose( + ir_x_shape, static_axis_boundary, transpose_axis_boundary + ) + out_types = [ + ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + contracted_x_shape = ( + reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), + ) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + ) + out = custom_caller( + CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2} + ) return out @staticmethod @@ -381,6 +413,15 @@ def cast_transpose( cast transpose wrapper Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` """ + if not CastTransposePrimitive.enabled(): + return _jax_cast_transpose( + x, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) return CastTransposePrimitive.outer_primitive.bind( x, amax, @@ -631,6 +672,28 @@ def dbias_cast_transpose( if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes + if not DBiasCastTransposePrimitive.enabled(): + casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( + dz, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + dbias = jnp.sum( + dz, + axis=tuple( + range( + transpose_axis_boundary + if transpose_axis_boundary > 0 + else transpose_axis_boundary + dz.ndim + ) + ), + keepdims=False, + ) + return casted_dz, cast_transposed_dz, dbias, updated_amax + return DBiasCastTransposePrimitive.outer_primitive.bind( dz, amax, @@ -947,6 +1010,31 @@ def dact_lu_dbias_cast_transpose( if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes + if not DActLuDBiasCastTransposePrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) + (dx,) = vjp_func(dz) + casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose( + dx, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + dbias = jnp.squeeze( + jnp.sum( + dx, + axis=tuple( + range( + transpose_axis_boundary + if transpose_axis_boundary > 0 + else transpose_axis_boundary + dx.ndim + ) + ), + ) + ) + return casted_dx, cast_transposed_dx, dbias, updated_amax + act_type_id = ActivationEnum[activation_type] return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( dz, @@ -1161,6 +1249,17 @@ def dgated_act_lu_cast_transpose( Return FP8(dgated_act_lu(inputs)) """ act_type_id = ActivationEnum[activation_type] + if not DgatedActLuCastTransposePrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) + (dx,) = vjp_func(dz) + return _jax_cast_transpose( + dx, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=-2, + ) return DgatedActLuCastTransposePrimitive.outer_primitive.bind( dz, x, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c541fb8afa..b872370715 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -13,8 +13,6 @@ #include #include #include -#include -#include #include #include @@ -27,23 +25,14 @@ #include "common/common.h" #include "common/util/logging.h" +#include "extensions/ffi.h" +#include "extensions/misc.h" +#include "transformer_engine/activation.h" #include "utils.h" namespace transformer_engine { namespace jax { -constexpr int kMaxNumDim = 8; - -// TODO: Rename Shape to ??? -struct Shape { - int num_dim; - size_t dims[kMaxNumDim]; - - void from_vector(const std::vector &shape); - - std::vector to_vector() const; -}; - // Phuong: These 3 functions need to stay in the header file for compilation purpose // 1. inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -62,8 +51,6 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) { return reinterpret_cast(opaque); } -std::vector MakeShapeVector(NVTEShape shape); - // Packing struct CustomCallCommonDescriptor { @@ -147,6 +134,7 @@ struct CustomCallFusedAttnDescriptor { DType dtype; DType wkspace_dtype; bool is_training; + bool deterministic; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( @@ -154,7 +142,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training); + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic); // Transpose @@ -165,6 +154,8 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler); + void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); // Activation @@ -177,6 +168,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler); + pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); @@ -191,7 +186,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps); + float eps, int sm_margin); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -201,7 +196,7 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps); + float eps, int sm_margin); void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -260,7 +255,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq); + bool deterministic, size_t max_segments_per_seq); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 51563a8ccd..1e8998b365 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -3,15 +3,16 @@ * * See LICENSE for license information. ************************************************************************/ - #include "transformer_engine/activation.h" #include "extensions.h" #include "transformer_engine/transpose.h" +#include "xla/ffi/api/c_api.h" namespace transformer_engine { namespace jax { +// TODO: We won't need this function anymore when we move to the new XLA custom calls size_t get_activation_len(NVTE_Activation_Type activation_enum) { switch (activation_enum) { case NVTE_Activation_Type::GELU: @@ -43,8 +44,7 @@ size_t get_activation_len(NVTE_Activation_Type activation_enum) { void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, cudaStream_t stream, float *scale_inverse, float *amax, void *output, - NVTE_Activation_Type act_enum) { - auto act_len = get_activation_len(act_enum); + NVTE_Activation_Type act_enum, size_t act_len) { auto input_shape = std::vector{m, n * act_len}; auto output_shape = std::vector{m, n}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); @@ -95,12 +95,39 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; + auto act_len = get_activation_len(act_enum); ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output, - act_enum); + act_enum, act_len); } +Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, + int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *output = output_buf->untyped_data(); + + auto input_dims = input_buf.dimensions(); + auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>()); + auto n = input_dims.back(); + auto act_len = input_dims.end()[-2]; + auto act_type = static_cast(act_enum); + + ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type, + act_len); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("act_enum")); + void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; float *amax = reinterpret_cast(buffers[1]); @@ -119,10 +146,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; + auto act_len = get_activation_len(act_enum); ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output, - act_enum); + act_enum, act_len); } void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -134,7 +161,6 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; auto act_len = get_activation_len(act_enum); auto input_shape = std::vector{m, n}; @@ -182,6 +208,76 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq } } +Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, + Result_Type output_buf, int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *act_input = act_input_buf.untyped_data(); + auto *output = output_buf->untyped_data(); + + auto act_input_dims = act_input_buf.dimensions(); + auto m = + std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, 1, std::multiplies<>()); + auto n = act_input_dims.back(); + auto act_len = act_input_dims.end()[-2]; + + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n * act_len}; + + auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, static_cast(in_dtype)); + auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype)); + + auto act_type = static_cast(act_enum); + switch (act_type) { + case NVTE_Activation_Type::GELU: + nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::GEGLU: + nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SILU: + nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SWIGLU: + nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::RELU: + nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::REGLU: + nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGELU: + nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGEGLU: + nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SRELU: + nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Type::SREGLU: + nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act_input + .Ret() // output + .Attr("act_enum")); + pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { auto input_shape = std::vector{batch_size, hidden_size}; diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 640869ac36..1d367f5cc1 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, auto backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim, -1, -1); + head_dim, head_dim, -1, -1); return backend; } @@ -139,7 +139,13 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; - for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) { + size_t min_num_segments = input_batch; + auto cudnn_runtime_version = cudnnGetVersion(); + if (is_ragged && cudnn_runtime_version >= 90300) { + // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 + min_num_segments = input_batch * max_segments_per_seq; + } + for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); @@ -227,14 +233,19 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments if (is_ragged) { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; + auto cudnn_runtime_version = cudnnGetVersion(); + if (cudnn_runtime_version >= 90300) { + num_segments = input_batch * max_segments_per_seq; + } else { + // workspace can be reused here as it is not used with cuDNN graph at the same time + size_t runtime_num_segments_q = + GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); + size_t runtime_num_segments_kv = + GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); + NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); + NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); + num_segments = runtime_num_segments_q; + } cudaMemsetAsync(output, 0, input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream); } @@ -255,10 +266,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim, head_dim, -1, -1); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -325,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq) { + bool deterministic, size_t max_segments_per_seq) { // For qkv_packed auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); @@ -366,7 +377,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; - for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) { + size_t min_num_segments = input_batch; + auto cudnn_runtime_version = cudnnGetVersion(); + if (is_ragged && cudnn_runtime_version >= 90300) { + // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 + min_num_segments = input_batch * max_segments_per_seq; + } + for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); @@ -375,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, -1, -1, deterministic, + query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -391,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, true, query_workspace_tensor.data(), nullptr); + -1, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -402,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, true, query_workspace_tensor.data(), nullptr); + -1, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -450,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; auto dtype = descriptor.dtype; + auto deterministic = descriptor.deterministic; auto max_segments_per_seq = descriptor.max_segments_per_seq; /* Input tensors */ @@ -460,14 +479,19 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments if (is_ragged) { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; + auto cudnn_runtime_version = cudnnGetVersion(); + if (cudnn_runtime_version >= 90300) { + num_segments = input_batch * max_segments_per_seq; + } else { + // workspace can be reused here as it is not used with cuDNN graph at the same time + size_t runtime_num_segments_q = + GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); + size_t runtime_num_segments_kv = + GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); + NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); + NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); + num_segments = runtime_num_segments_q; + } } auto q_cu_seqlens_tensor = @@ -486,10 +510,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim, head_dim, -1, -1); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, rng_state, bias); @@ -517,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream); + bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -544,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, + dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q = buffers[0]; @@ -580,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, - workspace_tensor.data(), stream); + dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, + deterministic, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp new file mode 100644 index 0000000000..19fd50cbd1 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#include "extensions/ffi.h" + +#include + +#include "common/util/logging.h" + +namespace transformer_engine { +namespace jax { + +// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186 +DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { + switch (type) { + case xla::ffi::DataType::F16: + return DType::kFloat16; + break; + case xla::ffi::DataType::F32: + return DType::kFloat32; + break; + case xla::ffi::DataType::BF16: + return DType::kBFloat16; + break; + case xla::ffi::DataType::F8E5M2: + return DType::kFloat8E5M2; + break; + case xla::ffi::DataType::F8E4M3FN: + return DType::kFloat8E4M3; + break; + default: + auto type_num = static_cast(type); + NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", + static_cast(type_num)); + break; + } +} + +Error_Type ffi_with_cuda_error_check() { + cudaError_t last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + return Error_Type(XLA_FFI_Error_Code_INTERNAL, + std::string("CUDA error: ") + cudaGetErrorString(last_error)); + } + return Error_Type::Success(); +} + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h new file mode 100644 index 0000000000..77132c3fca --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -0,0 +1,25 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include + +namespace transformer_engine { +namespace jax { + +using Buffer_Type = xla::ffi::AnyBuffer; +using Result_Type = xla::ffi::Result; +using Error_Type = xla::ffi::Error; +using FFI = xla::ffi::Ffi; +using FFI_Stream_Type = xla::ffi::PlatformStream; + +DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type); +Error_Type ffi_with_cuda_error_check(); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h new file mode 100644 index 0000000000..7f6179e91c --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -0,0 +1,30 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include +#include + +namespace transformer_engine { +namespace jax { + +constexpr int kMaxNumDim = 8; + +struct Shape { + int num_dim; + size_t dims[kMaxNumDim]; + + void from_vector(const std::vector &shape); + + std::vector to_vector() const; +}; + +std::vector MakeShapeVector(NVTEShape shape); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 9585e2edf1..fb40400e62 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -3,7 +3,6 @@ * * See LICENSE for license information. ************************************************************************/ - #include "extensions.h" #include "transformer_engine/layer_norm.h" #include "transformer_engine/rmsnorm.h" @@ -14,7 +13,7 @@ namespace jax { pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps) { + float eps, int sm_margin) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -27,7 +26,7 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd // dummy tensor wrappers that will carry workspace size info later TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; if (is_layer_norm) { auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); @@ -54,7 +53,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac DType in_dtype, void *weight, DType w_dtype, void *bias, void *output, DType out_dtype, void *workspace, DType work_dtype, void *barrier, DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale, - float *scale_inv, cudaStream_t stream) { + float *scale_inv, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -71,7 +70,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv); auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); @@ -95,7 +94,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, - float eps) { + float eps, int sm_margin) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -112,7 +111,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid // dummy tensor wrappers that will carry workspace size info later TensorWrapper dummy_work_tensor, dummy_barrier_tensor; TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; // initialize dBeta information here -- layernorm will modify but RMSnorm will not @@ -152,7 +151,7 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace void *weight, DType w_dtype, void *ograd, void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part, - DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, + DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; @@ -174,7 +173,7 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype); auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto workspace_shape = std::vector{wkspace_size}; @@ -228,13 +227,14 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -263,11 +263,12 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto eps = desc.eps; auto out_dtype = in_dtype; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -287,6 +288,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dbeta_part_dtype = desc.dbeta_part_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto *ograd = buffers[0]; auto *mu = buffers[1]; @@ -305,7 +307,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, stream); + dbeta_part_dtype, sm_margin, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -335,12 +337,13 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -368,12 +371,13 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; auto out_dtype = in_dtype; LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + sm_margin, stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -407,12 +411,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto dbeta_part_dtype = DType::kByte; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; + auto sm_margin = desc.sm_margin; LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, stream); + dbeta_part_dtype, sm_margin, stream); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 8c948d0a8f..128564db64 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) { + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic) { return PackOpaque(CustomCallFusedAttnDescriptor{ input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, - mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); + mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic}); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 95fe3101c9..0a2172bb1b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -14,6 +14,13 @@ pybind11::capsule EncapsulateFunction(T *fn) { return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); } +template +pybind11::capsule EncapsulateFFI(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be an XLA FFI handler"); + return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + pybind11::dict Registrations() { pybind11::dict dict; dict["te_transpose"] = EncapsulateFunction(Transpose); @@ -44,6 +51,10 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + + dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); + dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); + dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); return dict; } @@ -59,6 +70,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); + m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); @@ -113,7 +125,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGELU", NVTE_Activation_Type::QGELU) .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU); + .value("SREGLU", NVTE_Activation_Type::SREGLU) + .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 3e53b7521f..7a2e31312a 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/transpose.h" #include "extensions.h" +#include "xla/ffi/api/ffi.h" namespace transformer_engine { namespace jax { @@ -66,6 +67,61 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size stream); } +Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, + Result_Type input_cast_buf, Result_Type input_cast_trans_buf, + Result_Type amax_out_buf, int64_t transpose_axis) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type()); + + auto *input = input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *input_cast = input_cast_buf->untyped_data(); + auto *input_cast_trans = input_cast_trans_buf->untyped_data(); + float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + assert(amax == amax_out); + + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + if (transpose_axis < 0) transpose_axis += input_dims.size(); + auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1, + std::multiplies<>()); + auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1, + std::multiplies<>()); + auto input_shape = std::vector{m, n}; + auto input_trans_shape = std::vector{n, m}; + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto input_cast_tensor = + TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv); + auto input_cast_trans_tensor = + TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv); + + nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), + stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // input_cast + .Ret() // input_cast_trans + .Ret() // amax_out + .Attr("transpose_axis")); + pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { auto input_shape = std::vector{batch_size, hidden_size}; diff --git a/transformer_engine/jax/csrc/utils.cu b/transformer_engine/jax/csrc/utils.cu index d9451dca32..8ca34013b3 100644 --- a/transformer_engine/jax/csrc/utils.cu +++ b/transformer_engine/jax/csrc/utils.cu @@ -19,6 +19,8 @@ int GetCudaRuntimeVersion() { return ver; } +size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); } + int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed, diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index fd3ebe8d8c..32de33bac9 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -22,6 +22,7 @@ namespace transformer_engine { namespace jax { int GetCudaRuntimeVersion(); +size_t GetCudnnRuntimeVersion(); int GetDeviceComputeCapability(int gpu_id); void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index d53a4e5202..c62c2bb77d 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods kernel is not available on the system, a warning will be issued, and the module will automatically fall back to the unfused backend. + .. note:: + The DotProductAttention default setting enables non-deterministic kernels for reduced + workspace requirements and faster computation. Users can disable the non-deterministic + kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable: + + * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels. + * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default). + Parameters ---------- head_dim: int diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index e7364a13b6..4f2e83d9a2 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -69,14 +69,14 @@ def _layernorm_fwd_rule( mu = None else: raise ValueError(f"{layernorm_type=} is not supported.") - return output, (x, mu, rsigma, gamma) + return output, (x, mu, rsigma, gamma, beta) def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): - x, mu, rsigma, gamma = ctx + x, mu, rsigma, gamma, beta = ctx if layernorm_type == "layernorm": dx, dgamma, dbeta = tex.layernorm_bwd( - dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + dz, x, mu, rsigma, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) elif layernorm_type == "rmsnorm": assert ( @@ -267,6 +267,7 @@ def _layernorm_fp8_dot_fwd_rule( rsigma, x, gamma, + beta, x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32, @@ -300,6 +301,7 @@ def _layernorm_fp8_dot_bwd_rule( rsigma, x, gamma, + beta, x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32, @@ -352,7 +354,14 @@ def _layernorm_fp8_dot_bwd_rule( dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) if layernorm_type == "layernorm": dx, dgamma, dbeta = tex.layernorm_bwd( - dgrad, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + dgrad, + x, + mu, + rsigma, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, ) else: assert ( diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 0017acb80c..90504e4c14 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -344,6 +344,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( mu, rsigma, gamma, + beta, dot_1_output, casted_activation_lu_out, casted_kernel_1, @@ -390,6 +391,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( mu, rsigma, gamma, + beta, dot_1_output, casted_activation_lu_out, casted_kernel_1, @@ -568,7 +570,14 @@ def _fused_layernorm_fp8_mlp_bwd_rule( if layernorm_type == "layernorm": dx, dgamma, dbeta = tex.layernorm_bwd( - dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon + dgrad_1, + x, + mu, + rsigma, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, ) else: assert ( diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index 0a997776ef..c63ee85e5d 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -49,18 +49,18 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): else: output = tex.scaled_softmax_fwd(logits, scale_factor) - return output, (output,) + return output, (output, logits, mask) def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): - (softmax_output,) = ctx + (softmax_output, logits, mask) = ctx if softmax_type is SoftmaxType.SCALED_MASKED: - dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: - dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor) else: - dgrad = tex.scaled_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor) return (dgrad, None) diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 62fa1fe626..50cf2186d6 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -6,9 +6,41 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging +from importlib.metadata import version + +from transformer_engine.common import is_package_installed + def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_paddle" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[paddle]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[paddle]==VERSION'", + module_name, + ) + from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index e12c0dd3c4..7860da2496 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -593,6 +593,9 @@ def fused_attn_fwd_qkvpacked( if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype) else: @@ -656,6 +659,7 @@ def fused_attn_bwd_qkvpacked( qkv_layout: str = "bs3hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", + deterministic: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed QKV input""" @@ -676,13 +680,19 @@ def fused_attn_bwd_qkvpacked( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) else: dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype) if bias_type != "no_bias": - dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) + if qkv_format == "thd": + dbias = paddle.zero(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) + else: + dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) else: dbias = None # execute kernel @@ -706,6 +716,7 @@ def fused_attn_bwd_qkvpacked( bias_type, attn_mask_type, int(qkv_dtype), + deterministic, ) return dqkv, dbias @@ -772,6 +783,9 @@ def fused_attn_fwd_kvpacked( if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) else: @@ -843,6 +857,7 @@ def fused_attn_bwd_kvpacked( qkv_layout: str = "bshd_bs2hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", + deterministic: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed KV input""" @@ -867,6 +882,9 @@ def fused_attn_bwd_kvpacked( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype) @@ -874,7 +892,10 @@ def fused_attn_bwd_kvpacked( dq = paddle.empty(shape=q.shape, dtype=q.dtype) dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype) if bias_type != "no_bias": - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + if qkv_format == "thd": + dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + else: + dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) else: dbias = None # execute kernel @@ -903,6 +924,7 @@ def fused_attn_bwd_kvpacked( bias_type, attn_mask_type, int(qkv_dtype), + deterministic, ) return dq, dkv, dbias @@ -970,6 +992,9 @@ def fused_attn_fwd( if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) else: @@ -1040,6 +1065,7 @@ def fused_attn_bwd( qkv_layout: str = "bshd_bshd_bshd", bias_type: str = "no_bias", attn_mask_type: str = "padding", + deterministic: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Fused Attention BWD for packed KV input""" @@ -1065,6 +1091,9 @@ def fused_attn_bwd( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + if qkv_format == "thd": + set_zero = True if set_zero: dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype) @@ -1074,7 +1103,10 @@ def fused_attn_bwd( dk = paddle.empty(shape=k.shape, dtype=k.dtype) dv = paddle.empty(shape=v.shape, dtype=v.dtype) if bias_type != "no_bias": - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + if qkv_format == "thd": + dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + else: + dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) else: dbias = None # execute kernel @@ -1103,6 +1135,7 @@ def fused_attn_bwd( bias_type, attn_mask_type, int(qkv_dtype), + deterministic, ) return dq, dk, dv, dbias diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 60f06a2188..6ce250432a 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -131,10 +131,10 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(static_cast(q_dtype), static_cast(kv_dtype), - qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, - num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, + attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + head_dim, head_dim, -1, -1); return fused_attention_backend; } diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 69569d5584..904d979b8e 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -708,7 +708,8 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor int64_t b, int64_t h, int64_t d, int64_t total_seqs, int64_t max_seqlen, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, int64_t qkv_type) { + const std::string &attn_mask_type, int64_t qkv_type, + bool deterministic) { TensorWrapper te_dBias; if (bias_type != "no_bias" && dBias) { auto bias_shape = dBias->shape(); @@ -759,22 +760,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), - te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream()); + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, + deterministic, workspace.data(), QKV.stream()); // allocate memory for workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); // execute kernel - nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), - te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream()); + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, + te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, + attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, + deterministic, workspace.data(), QKV.stream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -884,7 +885,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type) { + int64_t qkv_type, bool deterministic) { TensorWrapper te_dBias; if (bias_type != "no_bias" && dBias) { auto bias_shape = dBias->shape(); @@ -945,7 +946,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, true, workspace.data(), Q.stream()); + -1, -1, deterministic, workspace.data(), Q.stream()); // allocate memory for workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); @@ -957,7 +958,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, true, workspace.data(), Q.stream()); + -1, -1, deterministic, workspace.data(), Q.stream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -1086,7 +1087,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, float attn_scale, float p_dropout, const std::string &qkv_layout, const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type) { + int64_t qkv_type, bool deterministic) { TensorWrapper te_dBias; if (bias_type != "no_bias" && dBias) { auto bias_shape = dBias->shape(); @@ -1149,7 +1150,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream()); + attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); // allocate memory for workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); @@ -1161,7 +1162,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream()); + attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -1657,7 +1658,8 @@ PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) .Outputs({"dQKV", paddle::Optional("dBias")}) .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"}) + "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", + "deterministic: bool"}) .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked)); @@ -1682,7 +1684,8 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked) .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"}) + "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", + "deterministic: bool"}) .SetInplaceMap({{"_dQ", "dQ"}, {"_dKV", "dKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) @@ -1708,7 +1711,7 @@ PD_BUILD_OP(te_fused_attn_bwd) .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t"}) + "qkv_type: int64_t", "deterministic: bool"}) .SetInplaceMap({{"_dQ", "dQ"}, {"_dK", "dK"}, {"_dV", "dV"}, diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 98e50b9e04..75a3513d14 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -152,6 +152,7 @@ def forward( attn_bias_type, attn_mask_type, is_training, + deterministic, fused_attention_backend, ): """Forward function for FusedAttention with packed QKV input""" @@ -180,6 +181,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic ctx.fused_attention_backend = fused_attention_backend return out @@ -204,6 +206,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.deterministic, ) # if no_bias, return dqkv @@ -234,6 +237,7 @@ def forward( attn_bias_type, attn_mask_type, is_training, + deterministic, fused_attention_backend, ): """Forward function for FusedAttention with packed KV input""" @@ -266,6 +270,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic ctx.fused_attention_backend = fused_attention_backend return out @@ -293,6 +298,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.deterministic, ) # if no_bias, return dq, dkv @@ -324,6 +330,7 @@ def forward( attn_bias_type, attn_mask_type, is_training, + deterministic, fused_attention_backend, ): """Forward function for FusedAttention with separate Q, K, V tensors""" @@ -357,6 +364,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic ctx.fused_attention_backend = fused_attention_backend return out @@ -385,6 +393,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.deterministic, ) # if no_bias, return dq, dk, dv if ctx.attn_bias_type == "no_bias": @@ -404,6 +413,12 @@ class DotProductAttention(paddle.nn.Layer): Argument :attr:`attention_mask` will be ignored in the `forward` call when :attr:`attn_mask_type` is set to `"causal"`. + .. warning:: + + Fused attention backward uses a non-deterministic algorithm when workspace + optimization is not enabled. To use a deterministic algorithm, set the + environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` + Parameters ---------- num_attention_heads: int @@ -458,6 +473,29 @@ def __init__( self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) + self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + + # To use the workspace optimization path for determinism, please + # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0, + # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0. + cudnn_version = paddle.get_cudnn_version() + if 8905 <= cudnn_version < 9000: + if self.deterministic: + # workspace optimization path is deterministic + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" + + # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT + # - unset: enables workspace optimization when required workspace is <= 256MB + # or when bias gradient needs to be computed + # - n: enables workspace optimization when required workspace is <= n bytes + # - -1: enables workspace optimization always + # - 0: disables workspace optimization always + if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ: + if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0": + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" + if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" + if not self.use_fused_attention and backend == "transformer_engine": warnings.warn("Fused attention is not enabled, falling back to Paddle backend") self.backend = "paddle" @@ -603,6 +641,7 @@ def _te_forward( core_attention_bias_type, self.attn_mask_type, self.training, + self.deterministic, self.fused_attention_backend, ) elif self.attention_type == "cross": @@ -637,6 +676,7 @@ def _te_forward( core_attention_bias_type, self.attn_mask_type, self.training, + self.deterministic, self.fused_attention_backend, ) else: diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 20b6f79da6..07ade71905 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -6,25 +6,54 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import importlib +import importlib.util import sys import torch +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_torch" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[pytorch]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[pytorch]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) - module_name = "transformer_engine_torch" spec = importlib.util.spec_from_file_location(module_name, so_path) solib = importlib.util.module_from_spec(spec) sys.modules[module_name] = solib diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f71b469f2d..ff121527d3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -65,6 +65,8 @@ set_all_rng_states, CudaRNGStatesTracker, graph_safe_rng_available, + gather_along_first_dim, + reduce_scatter_along_first_dim, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo @@ -79,6 +81,7 @@ _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") +_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func @@ -92,17 +95,20 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) +_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL +_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} +_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] +_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") +_stream_handler = logging.StreamHandler() +_stream_handler.setFormatter(_formatter) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) @@ -141,8 +147,10 @@ class AttentionParams: Maximum sequence length of the query tensor. max_seqlen_kv: int, default = 128 Maximum sequence length of the key and value tensors. - head_dim: int, default = 64 - The size of each attention head. + head_dim_qk: int, default = 64 + The size of each attention head in query and key tensors. + head_dim_v: int, default = 64 + The size of each attention head in the value tensor. attn_mask_type: str, default = `no_mask` Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} @@ -181,7 +189,8 @@ class AttentionParams: num_gqa_groups: int = 16 max_seqlen_q: int = 128 max_seqlen_kv: int = 128 - head_dim: int = 64 + head_dim_qk: int = 64 + head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None alibi_slopes_shape: Union[torch.Size, List, None] = None @@ -244,7 +253,8 @@ def get_attention_backend( num_gqa_groups = attention_params.num_gqa_groups max_seqlen_q = attention_params.max_seqlen_q max_seqlen_kv = attention_params.max_seqlen_kv - head_dim = attention_params.head_dim + head_dim_qk = attention_params.head_dim_qk + head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size alibi_slopes_shape = attention_params.alibi_slopes_shape @@ -261,6 +271,9 @@ def get_attention_backend( # Run config logger = logging.getLogger("DotProductAttention") + logger.setLevel(_log_level) + if not logger.hasHandlers(): + logger.addHandler(_stream_handler) device_compute_capability = get_device_compute_capability() cudnn_version = get_cudnn_version() run_config = { @@ -313,13 +326,6 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as it requires compute capability sm80+") use_fused_attention = False - # Filter: Context parallelism - if context_parallel and use_unfused_attention: - logger.debug( - "Disabling UnfusedDotProductAttention as it does not support context parallelism" - ) - use_unfused_attention = False - # Filter: Data type if use_flash_attention and ( qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor @@ -351,19 +357,31 @@ def get_attention_backend( use_unfused_attention = False # Filter: Head dimension + if use_flash_attention and head_dim_qk != head_dim_v: + logger.debug("Disabling FlashAttention as it does not support MLA.") + use_flash_attention = False if use_flash_attention and ( - head_dim > 256 - or head_dim % 8 != 0 - or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0))) + head_dim_qk > 256 + or head_dim_qk % 8 != 0 + or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0))) ): logger.debug( - "Disabling FlashAttention due to unsupported head_dim. " - "Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). " - "Found: head_dim = %s on sm%s.", - head_dim, + "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " + "head_dim_qk <= 256 (>192 requires sm80/90). " + "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", + head_dim_qk, + head_dim_v, ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False # Filter: QKV layout qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -378,20 +396,101 @@ def get_attention_backend( ) use_flash_attention = False + # Filter: Context parallelism + # qkv_format | attn_mask_type | attn_bias_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention + # | no_mask, causal | | + # | cross-attention: | | + # | no_mask | | + # thd | self-attention: | no_bias | FlashAttention, FusedAttention + # | padding, padding_causal | | if no padding between sequences, + # | cross-attention: | | FusedAttention + # | padding | | if there is padding between sequences + # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v. + if context_parallel and use_unfused_attention: + logger.debug( + "Disabling UnfusedDotProductAttention as it does not support context parallelism" + ) + use_unfused_attention = False + if context_parallel and use_flash_attention: + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_flash_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with causal" + " masking for cross-attention" + ) + use_flash_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with bias type" + " of %s", + core_attention_bias_type, + ) + use_flash_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with attention" + " bias for THD format" + ) + use_flash_attention = False + if context_parallel and use_fused_attention: + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_fused_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with causal" + " masking for cross-attention" + ) + use_fused_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias type" + " of %s", + core_attention_bias_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with attention" + " bias for THD format" + ) + use_fused_attention = False + elif head_dim_qk != head_dim_v: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with MLA" + ) + use_fused_attention = False + # Filter: Attention mask - # attn_mask_type | supported backends - # ------------------------------------------------------------------- - # no_mask | All - # padding | FlashAttention, FusedAttention - # causal | - # self-attention | All - # cross-attention | FusedAttention - # padding_causal | - # self-attention | FlashAttention, FusedAttention - # cross-attention | FusedAttention - # causal_bottom_right | All - # padding_causal_bottom_right | FlashAttention, FusedAttention - # arbitrary | UnfusedDotProductAttention + # attn_mask_type | attention_mask | supported backends + # ---------------------------------------------------------------------------------------- + # no_mask | None | All + # padding | | All + # self-attention | One tensor in shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors in shapes | + # | [b, 1, 1, sq] and [b, 1, 1, skv] | + # causal | None | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # padding_causal | Same as "padding" | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # causal_bottom_right | None | All + # padding_causal_bottom_right | Same as "padding" | + # self-attention | | All + # cross-attention | | FlashAttention, UnfusedDotProductAttention + # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": if use_flash_attention: logger.debug("Disabling FlashAttention for arbitrary mask") @@ -399,9 +498,6 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False - if use_unfused_attention and "padding" in attn_mask_type: - logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type) - use_unfused_attention = False if ( use_flash_attention and _flash_attn_2_1_plus @@ -478,11 +574,10 @@ def get_attention_backend( if ( use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and (not _flash_attn_2_3_plus or context_parallel) + and not _flash_attn_2_3_plus ): logger.debug( - "Disabling FlashAttention as sliding window attention requires " - "flash-attn 2.3+ and no context parallelism" + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" ) use_flash_attention = False @@ -556,7 +651,8 @@ def get_attention_backend( num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, + head_dim_qk, + head_dim_v, window_size[0], window_size[1], ) @@ -564,18 +660,6 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - if ( - use_fused_attention - and context_parallel - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "context parallellism", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None if ( use_fused_attention and window_size is not None @@ -699,7 +783,7 @@ def get_attention_backend( class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. Parameters ---------- @@ -805,6 +889,8 @@ def get_alibi( num_heads: int, max_seqlen_q: int, max_seqlen_kv: int, + actual_seqlens_q: Optional[torch.Tensor] = None, + actual_seqlens_kv: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, bias_dtype: Optional[torch.dtype] = None, bottom_right_alignment: bool = True, @@ -818,6 +904,10 @@ def get_alibi( Maximum sequence length for queries. max_seqlen_kv: int Maximum sequence length for keys and values. + actual_seqlens_q: Optional[torch.Tensor], default = `None` + Actual sequence lengths for queries, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + Actual sequence lengths for keys and values, in shape [batch_size]. alibi_slopes: Optional[torch.Tensor], default = `None` Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. bias_dtype: Optional[torch.dtype], default = `None` @@ -831,10 +921,12 @@ def get_alibi( alibi_slopes: torch.Tensor ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. alibi_bias: torch.Tensor - ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape, - then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if - `alibi_slopes` is in [batch_size, num_heads], then the bias is in - [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. + ALiBi bias in FP32 or `bias_dtype`. Its shape is + (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, + and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or + (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in + [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and + `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. """ global _alibi_cache if _alibi_cache["_alibi_slopes_require_update"]: @@ -860,17 +952,23 @@ def get_alibi( slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) if _alibi_cache["_alibi_slopes"].dim() == 2: slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - if bottom_right_alignment: - bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - else: - bias = torch.arange( - 1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda" - ).view(1, 1, 1, max_seqlen_kv) - bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view( + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv ) + if actual_seqlens_q is None and actual_seqlens_kv is None: + if bottom_right_alignment: + bias = bias + max_seqlen_kv - max_seqlen_q + elif actual_seqlens_q is not None and actual_seqlens_kv is not None: + batch_size = actual_seqlens_q.shape[0] + bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + if bottom_right_alignment: + bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + else: + assert ( + False + ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv @@ -1180,11 +1278,32 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step): softmax_lse.copy_(new_scale) -class AttnFuncWithCP(torch.autograd.Function): +@jit_fuser +def get_cu_seqlens_on_cp_rank( + cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half +): + """Compute cu_seqlens of a context parallelism rank""" + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2 + zeros = torch.zeros_like(seqlens) + cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens) + if first_half: + seqlens_1 = seqlens - cp_rank * seqlens_padded + seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_1) + if second_half: + seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded + seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_2) + cu_seqlens_on_cp_rank.cumsum_(dim=0) + return cu_seqlens_on_cp_rank + + +class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ - Attention implementation with context parallelism. - Split attention compute into multiple steps, and overlap current-step - compute with next-step communication. + Attention implementation with context parallelism. Exchange KV between CP ranks + with P2P in ring topology. Split attention compute into multiple steps, and overlap + current-step compute with next-step communication. """ @staticmethod @@ -1195,9 +1314,9 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_kv, max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, @@ -1211,6 +1330,8 @@ def forward( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1224,8 +1345,24 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - + if qkv_format in ["bshd", "sbhd"]: + seq_dim = qkv_format.index("s") + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + else: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) + pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + max_seqlen_q = max_seqlen_q // cp_size + max_seqlen_kv = max_seqlen_kv // cp_size + cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size + cu_seqlens_q_per_step = [None for _ in range(cp_size)] + cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + + assert qkv_format == "thd" or ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" if causal: if qkv_format == "bshd": # [b, s, np, hn] -> [b, 2, s//2, np, hn] @@ -1233,11 +1370,17 @@ def forward( elif qkv_format == "sbhd": # [s, b, np, hn] -> [2, s//2, b, np, hn] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + total_tokens_kv = None if qkv_format != "thd" else k.shape[0] + # remove padded tokens at the end + k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]] if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " "and [1, h, sq, sk] for backward!" ) + assert ( + attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-2], @@ -1253,9 +1396,11 @@ def forward( assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1] + fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) if _flash_attn_2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None # Flash Attn inputs q_inputs = [None, None] @@ -1272,8 +1417,48 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + else: + q_f16, k_f16, v_f16 = q, k, v + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + q_f16 = q + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + p2p_comm_buffers = [None for _ in range(cp_size)] - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + if use_fused_attention and qkv_format in ["bshd", "sbhd"]: + p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) + else: + p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] for i in range(cp_size + 1): @@ -1295,22 +1480,52 @@ def forward( batch_p2p_comm, ) - kv_inputs[i % 2] = p2p_comm_buffers[i] + if ( + not fp8 + or fp8_meta["recipe"].fp8_mha + or int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ): + kv_inputs[i % 2] = p2p_comm_buffers[i] + else: + # KV exchange is in BF16/FP16, cast received KV in each step + kv_inputs[i % 2] = cast_to_fp8( + p2p_comm_buffers[i], + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + ) + if fp8 and use_fused_attention: + fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] if causal: if i == 0: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view( - 2, k.shape[0], -1, *k.shape[-2:] + k.shape[0], -1, 2, *k.shape[-2:] ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": q_inputs[i % 2] = q if attn_bias is not None: @@ -1322,30 +1537,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, - q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1364,10 +1589,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, dropout_p, softmax_scale, causal=True, @@ -1375,61 +1600,88 @@ def forward( **fa_optional_forward_kwargs, ) elif i <= rank: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + False, + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous() elif qkv_format == "thd": q_inputs[i % 2] = q # [2, t, np, hn] -> [2, t/2, np, hn] kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_k, 0 + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_k // 2, - cu_seqlens_q, - cu_seqlens_k // 2, - q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None + if cu_seqlens_kv_padded is None + else cu_seqlens_kv_padded // 2 + ), + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) if qkv_format == "thd": # [2, t, np, hn] -> [2, t/2, np, hn] kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_k, 0 + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) else: # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] @@ -1437,7 +1689,7 @@ def forward( # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = [-1, -1] + fa_optional_forward_kwargs["window_size"] = (-1, -1) ( _, _, @@ -1451,10 +1703,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q, - cu_seqlens_k // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_k // 2, + max_seqlen_kv // 2, dropout_p, softmax_scale, causal=False, @@ -1462,22 +1714,43 @@ def forward( **fa_optional_forward_kwargs, ) else: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...].contiguous() - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view( - 2, k.shape[0], -1, *k.shape[-2:] + k.shape[0], -1, 2, *k.shape[-2:] ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_inputs[i % 2] = q[1].contiguous() - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -1487,38 +1760,50 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_k, - cu_seqlens_q // 2, - cu_seqlens_k, - q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=( + None + if cu_seqlens_q_padded is None + else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: if qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) else: # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] q_inputs[i % 2] = ( @@ -1527,7 +1812,7 @@ def forward( # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = [-1, -1] + fa_optional_forward_kwargs["window_size"] = (-1, -1) ( _, _, @@ -1541,10 +1826,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q // 2, - cu_seqlens_k, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q // 2, - max_seqlen_k, + max_seqlen_kv, dropout_p, softmax_scale, causal=False, @@ -1552,6 +1837,23 @@ def forward( **fa_optional_forward_kwargs, ) else: + if pad_between_seqs_q: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + if pad_between_seqs_kv: + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + else: + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size @@ -1562,30 +1864,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, - q, - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, sq, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1604,10 +1916,10 @@ def forward( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, dropout_p, softmax_scale, causal=False, @@ -1625,8 +1937,16 @@ def forward( softmax_lse_per_step[i - 1].squeeze_(-1) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + if fp8: + out_per_step[i - 1] = cast_from_fp8( + out_per_step[i - 1], + fp8_meta["scaling_fwd"], + META_O_CP, + fp8_dtype_forward, + TE_DType[torch.float32], + ) if i == 1: - out = torch.empty_like(q).zero_() + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] @@ -1640,7 +1960,10 @@ def forward( else: if qkv_format == "thd": tex.thd_second_half_lse_correction( - softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0) + softmax_lse, + softmax_lse_per_step[i - 1], + cu_seqlens_q_padded, + max_seqlen_q, ) else: flash_attn_fwd_softmax_lse_correction( @@ -1653,8 +1976,6 @@ def forward( torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) softmax_lse = softmax_lse.to(torch.float) - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") for i in range(cp_size): if qkv_format == "bshd": out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) @@ -1678,11 +1999,9 @@ def forward( out_per_step[i], softmax_lse, softmax_lse_per_step[i], - cu_seqlens_q, + cu_seqlens_q_padded, False, ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" else: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction( @@ -1698,11 +2017,9 @@ def forward( out_per_step[i], softmax_lse, softmax_lse_per_step[i], - cu_seqlens_q, + cu_seqlens_q_padded, True, ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" kv = p2p_comm_buffers[-1] if use_fused_attention: @@ -1713,23 +2030,66 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] + fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] + + out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) + if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) + + if fp8 and fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor( + data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + else: + out_ret = out_f16 + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, kv_save, out_save = q, kv, out_fp8 + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + elif fp8 and fp8_meta["recipe"].fp8_mha: + kv_fp8 = Float8Tensor( + data=kv, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=k_fp8.dtype, + ) + q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + else: + q_save, kv_save, out_save = q_f16, kv, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + ctx.save_for_backward( - q, - kv, - out, + q_save, + kv_save, + out_save, softmax_lse, - cu_seqlens_q, - cu_seqlens_k, cu_seqlens_q_padded, cu_seqlens_kv_padded, + fp8_fwd_scales, + fp8_fwd_scale_invs, + *cu_seqlens_q_per_step, + *cu_seqlens_kv_per_step, *rng_states, *attn_biases, ) ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks ctx.dropout_p = dropout_p + ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k + ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type @@ -1737,24 +2097,31 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - return out + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret @staticmethod def backward(ctx, dout): - (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6] - (cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8] cp_size = get_distributed_world_size(ctx.cp_group) - rng_states = ctx.saved_tensors[8 : 8 + cp_size] - attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] - rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] + (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] + cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + if ctx.qkv_format in ["bshd", "sbhd"]: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if attn_biases[0] is not None: # [b, np, sq, 2*cp, sk//(2*cp)] @@ -1770,7 +2137,9 @@ def backward(ctx, dout): if causal: if ctx.qkv_format == "thd": - softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) + softmax_lse_ = tex.thd_read_second_half_lse( + softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q + ) else: # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( @@ -1780,20 +2149,61 @@ def backward(ctx, dout): if ctx.use_fused_attention: # [b, np, sq//2] -> [b, np, sq//2, 1] softmax_lse_.unsqueeze_(-1) - if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) + dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + dkv_fp8_ = torch.empty_like(dkv_fp8) + dout_dtype = dout.dtype + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout = dout._data + else: + dout = cast_to_fp8( + dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]] + dq = torch.empty_like(q) + if ctx.qkv_format == "thd" and causal: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + p2p_comm_buffers = [ + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + ] + p2p_comm_buffers[0][0].copy_(kv) + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + out = out.view(*q.shape) dout = dout.view(*q.shape) - # Flash Attn outputs - dq = torch.empty_like(q) - - p2p_comm_buffers = [ - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - ] - p2p_comm_buffers[0][0].copy_(kv) send_recv_reqs = [] fa_optional_backward_kwargs = {} @@ -1809,18 +2219,40 @@ def backward(ctx, dout): send_tensor = p2p_comm_buffers[i % 2] recv_tensor = p2p_comm_buffers[(i + 1) % 2] - if i == 0: - send_tensor = send_tensor[0] - recv_tensor = recv_tensor[0] - if i == (cp_size - 1): - send_tensor = send_tensor[1] - recv_tensor = recv_tensor[1] - - send_recv_reqs = flash_attn_p2p_communicate( - rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm - ) + if ctx.fp8: + if i < cp_size - 1: + send_recv_reqs = flash_attn_p2p_communicate( + rank, + send_tensor[0], + send_dst, + recv_tensor[0], + recv_src, + ctx.cp_group, + batch_p2p_comm, + ) + else: + dkv_a2a_req = torch.distributed.all_to_all_single( + dkv_fp8, + dkv_fp8_, + group=ctx.cp_group, + async_op=True, + ) + send_recv_reqs = [dkv_a2a_req] + else: + if i == 0: + send_tensor = send_tensor[0] + recv_tensor = recv_tensor[0] + if i == (cp_size - 1): + send_tensor = send_tensor[1] + recv_tensor = recv_tensor[1] + send_recv_reqs = flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + ) kv = p2p_comm_buffers[i % 2][0] + if ctx.fp8 and ctx.use_fused_attention: + fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -1828,38 +2260,45 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] out_ = out.view(out.shape[0], -1, *out.shape[-2:]) dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_ = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, dout - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -1867,11 +2306,13 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) @@ -1879,7 +2320,7 @@ def backward(ctx, dout): out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, 0] + fa_optional_backward_kwargs["window_size"] = (-1, 0) _flash_attn_backward( dout_, q_, @@ -1890,10 +2331,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, - ctx.max_seqlen_k, + ctx.max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, True, @@ -1905,40 +2346,47 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous() + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0, ...].contiguous() # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] out_ = out.view(out.shape[0], -1, *out.shape[-2:]) dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_ = q.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] - kv_ = kv[:, 0, ...].contiguous() + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0].contiguous() # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, out_, dout_ = q, out, dout # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_k // 2, - cu_seqlens_q, - cu_seqlens_k // 2, + ctx.max_seqlen_kv // 2, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 @@ -1948,14 +2396,16 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) if ctx.qkv_format == "thd": # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) else: # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) @@ -1964,7 +2414,7 @@ def backward(ctx, dout): out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, -1] + fa_optional_backward_kwargs["window_size"] = (-1, -1) _flash_attn_backward( dout_, q_, @@ -1975,10 +2425,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q, - cu_seqlens_k // 2, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, - ctx.max_seqlen_k // 2, + ctx.max_seqlen_kv // 2, ctx.dropout_p, ctx.softmax_scale, False, @@ -1990,42 +2440,49 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_ = q[:, 1, ...].contiguous() - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] out_ = out[:, 1, ...].contiguous() dout_ = dout[:, 1, ...].contiguous() elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_ = q[1].contiguous() - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] out_ = out[1].contiguous() dout_ = dout[1].contiguous() elif ctx.qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) + q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) + out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) + dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) kv_ = kv - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, - ctx.max_seqlen_k, - cu_seqlens_q // 2, - cu_seqlens_k, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=( None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 ), @@ -2035,27 +2492,29 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: if ctx.qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) else: # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) if ctx.qkv_format == "thd": - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) + out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) + dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) else: # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, -1] + fa_optional_backward_kwargs["window_size"] = (-1, -1) _flash_attn_backward( dout_, q_, @@ -2066,10 +2525,10 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q // 2, - cu_seqlens_k, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q // 2, - ctx.max_seqlen_k, + ctx.max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, False, @@ -2078,23 +2537,26 @@ def backward(ctx, dout): ) else: if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_k, - cu_seqlens_q, - cu_seqlens_k, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], q, - kv[0], - kv[1], + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], out, dout, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2102,11 +2564,13 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, sq, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.empty_like(q_) + dq_ = torch.zeros_like(q_) # [2, b, sk, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) @@ -2114,7 +2578,7 @@ def backward(ctx, dout): out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = [-1, -1] + fa_optional_backward_kwargs["window_size"] = (-1, -1) _flash_attn_backward( dout_, q_, @@ -2125,16 +2589,19 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], - cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, - ctx.max_seqlen_k, + ctx.max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, False, + rng_state=rng_states[cp_size - i - 1], **fa_optional_backward_kwargs, ) + if ctx.fp8: + dq = dq_fp8[(rank + i + 1) % cp_size] if i >= (cp_size - rank - 1) or not causal: # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal # [b*sq, np, hn] -> [b, sq, np, hn] if not causal @@ -2147,135 +2614,665 @@ def backward(ctx, dout): # [b*sq//2, np, hn] -> [sq//2, b, np, hn] dq_ = dq_.view(-1, *dq.shape[-3:]) - if causal: + if ctx.fp8: + if i >= (cp_size - rank - 1) or not causal: + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].fill_(0) + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[0].fill_(0) + dq[1].copy_(dq_) + elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) elif i == (cp_size - rank - 1): if rank == (cp_size - 1): dq.copy_(dq_) else: - if ctx.qkv_format == "bshd": - dq[:, 0, ...].copy_(dq_[:, 0, ...]) - dq[:, 1, ...].add_(dq_[:, 1, ...]) - elif ctx.qkv_format == "sbhd": - dq[0].copy_(dq_[0]) - dq[1].add_(dq_[1]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add") - elif i > 0: - if ctx.qkv_format == "bshd": - dq[:, 1, ...].add_(dq_) - elif ctx.qkv_format == "sbhd": - dq[1].add_(dq_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add") - else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].copy_(dq_[:, 0, ...]) + dq[:, 1, ...].add_(dq_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dq[0].copy_(dq_[0]) + dq[1].add_(dq_[1]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add") + elif i > 0: + if ctx.qkv_format == "bshd": + dq[:, 1, ...].add_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].add_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add") + else: + if ctx.qkv_format == "bshd": + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].copy_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy") + else: + if i == 0: + dq.copy_(dq_) + else: + dq.add_(dq_) + + if attn_dbias is not None: + idx = (rank + i + 1) % cp_size + if i == (cp_size - 1) or not causal: + # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) + attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) + attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + elif i >= (cp_size - rank - 1): + # [b, np, sq, sk//(2*cp)] + attn_dbias[..., idx, :].copy_(dbias_) + else: + # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) + attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) + attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + + # wait until dKV is received + for req in send_recv_reqs: + req.wait() + + if ctx.fp8: + if i < cp_size - 1: + dkv = dkv_fp8_[(rank + i + 1) % cp_size] + else: + dkv = dkv_fp8[(rank + i + 1) % cp_size] + else: + dkv = p2p_comm_buffers[(i + 1) % 2][1] + if ctx.use_fused_attention: + dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) + if ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] + dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) + elif ctx.qkv_format == "sbhd": + # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn] + dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:]) + else: + # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal + # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal + dkv_ = dkv_.view(*dkv.shape) + + if ctx.fp8: + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + dkv[:, 1, ...].fill_(0) + else: + dkv.copy_(dkv_) + elif causal: + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) + dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_[:, 0, ...]) + dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy") + else: + dkv.add_(dkv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none") + else: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none") + elif i > 0: + dkv.add_(dkv_) + else: + dkv.copy_(dkv_) + else: + if i == 0: + dkv.copy_(dkv_) + else: + dkv.add_(dkv_) + + if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dq, dkv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV_CP, + fp8_dtype_backward, + TE_DType[torch.float32], + ) + for x in [dq_fp8, dkv_fp8] + ] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + + if causal: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + dq = dq.view(-1, *dq.shape[-3:]) + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + + if ctx.qkv_format == "thd": + dkv_ = torch.empty( + 2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device + ) + dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv) + dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) + dkv = dkv_ + + if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + dq, dkv = [ + cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) + for x in [dq, dkv] + ] + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_dtype, + ) + for x in [dq, dkv[0], dkv[1]] + ] + else: + dk, dv = dkv[0], dkv[1] + + if attn_dbias is not None: + # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + attn_dbias, + None, + None, + None, + None, + ) + + +@torch.compile +def get_seq_chunk_ids_to_all_gathered_kv( + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device +): + """Compute sequence chunk ids to the all-gathered KV.""" + seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv + seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left) + seqlen = seq_end_idx - seq_start_idx + num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv + chunk_ids = torch.arange( + local_chunk_id - num_chunks + 1, + local_chunk_id + 1, + dtype=torch.int32, + device=device, + ) + chunk_ids_to_all_gathered_kv = torch.where( + chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 + ) + return chunk_ids_to_all_gathered_kv + + +class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): + """ + Attention implementation with context parallelism. + KV all-gather between CP ranks is exposed. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + rank = get_distributed_rank(cp_group) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert causal and not padding, f"{attn_mask_type} mask type is not supported!" + if use_fused_attention and causal and "bottom_right" not in attn_mask_type: + attn_mask_type = attn_mask_type + "_bottom_right" + + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + use_fused_attention or _flash_attn_2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + fa_optional_forward_kwargs = {} + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = None + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + + max_seqlen_q = max_seqlen_q // (2 * cp_size) + max_seqlen_kv = max_seqlen_kv // (2 * cp_size) + cu_seqlens_q = cu_seqlens_q // (2 * cp_size) + cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size) + cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) + + if causal: + if qkv_format == "bshd": + # [b, s, np, hn] -> [b, 2, s//2, np, hn] + q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:]) + # [b, s, np, hn] -> [s, b, np, hn] + k, v = [x.transpose(0, 1).contiguous() for x in [k, v]] + elif qkv_format == "sbhd": + # [s, b, np, hn] -> [2, s//2, b, np, hn] + q = q.view(2, q.shape[0] // 2, *q.shape[1:]) + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + + k_ag, _ = gather_along_first_dim(k, cp_group) + v_ag, _ = gather_along_first_dim(v, cp_group) + cp_stream.wait_stream(torch.cuda.current_stream()) + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] + chunk_ids_to_kv_ag_per_step = [None, None] + out_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + out = torch.empty_like(q) + + for i in range(len(local_seq_chunk_ids) + 1): + if i < len(local_seq_chunk_ids): + with torch.cuda.stream(flash_attn_streams[i]): + chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + ( + max_seqlen_kv * cp_size * 2 + if (window_size is None or window_size[0] == -1) + else window_size[0] + ), + k.device, + ) + chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag + num_kv_chunks = chunk_ids_to_kv_ag.numel() + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_ = q[:, i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] + k_ = ( + torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(k.shape[1], -1, *k.shape[-2:]) + ) + v_ = ( + torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(v.shape[1], -1, *v.shape[-2:]) + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q[i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] + k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *k.shape[-3:] + ) + v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *v.shape[-3:] + ) + if use_fused_attention: + out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv * num_kv_chunks, + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + q_, + k_, + v_, + TE_DType[q.dtype], + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + window_size=window_size, + ) + else: + q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = ( + _flash_attn_forward( + q_, + k_, + v_, + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + max_seqlen_q, + max_seqlen_kv * num_kv_chunks, + dropout_p, + softmax_scale, + causal=True, + return_softmax=False, + window_size=window_size, + **fa_optional_forward_kwargs, + ) + ) + + if i > 0: + with torch.cuda.stream(flash_attn_streams[i - 1]): + if qkv_format == "bshd": + out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1])) + elif qkv_format == "sbhd": + out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1])) + + torch.cuda.current_stream().wait_stream(cp_stream) + + if use_fused_attention: + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + else: + out = out.view(-1, *out.shape[-2:]) + + ctx.save_for_backward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *chunk_ids_to_kv_ag_per_step, + *out_per_step, + *softmax_lse_per_step, + *rng_states, + ) + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.deterministic = deterministic + ctx.use_fused_attention = use_fused_attention + ctx.window_size = window_size + return out + + @staticmethod + def backward(ctx, dout): + cp_size = get_distributed_world_size(ctx.cp_group) + rank = get_distributed_rank(ctx.cp_group) + + (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ( + ctx.saved_tensors[:7] + ) + chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9] + out_per_step = ctx.saved_tensors[9:11] + softmax_lse_per_step = ctx.saved_tensors[11:13] + rng_states = ctx.saved_tensors[13:15] + + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + dout = dout.view_as(q) + dq = torch.empty_like(q) + dk = torch.zeros( + (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device + ) + dv = torch.zeros_like(dk) + dq_per_step = [None, None] + dk_per_step = [None, None] + dv_per_step = [None, None] + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream] + # synchronize dkv update across steps + dkv_update_done = torch.cuda.Event() + + k_ag, _ = gather_along_first_dim(k, ctx.cp_group) + v_ag, _ = gather_along_first_dim(v, ctx.cp_group) + ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] + + fa_optional_backward_kwargs = {} + if _flash_attn_2_4_plus: + fa_optional_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + + for i in range(len(local_seq_chunk_ids) + 1): + if i < len(local_seq_chunk_ids): + with torch.cuda.stream(flash_attn_streams[i]): + chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i] + num_kv_chunks = chunk_ids_to_kv_ag.numel() + out_ = out_per_step[i] + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_ = q[:, i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] + k_ = ( + torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(k.shape[1], -1, *k.shape[-2:]) + ) + v_ = ( + torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) + .movedim(2, 0) + .contiguous() + .view(v.shape[1], -1, *v.shape[-2:]) + ) + dout_ = dout[:, i].contiguous().view_as(out_) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q[i].contiguous() + # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] + k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *k.shape[-3:] + ) + v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( + -1, *v.shape[-3:] + ) + dout_ = dout[i].contiguous().view_as(out_) + if ctx.use_fused_attention: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.empty_like(x) for x in [q_, k_, v_] + ] + aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv * num_kv_chunks, + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + q_, + k_, + v_, + out_, + dout_, + TE_DType[q.dtype], + TE_DType[k.dtype], + aux_ctx_tensors, + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + window_size=ctx.window_size, + ) + else: + q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.empty_like(x) for x in [q_, k_, v_] + ] + _flash_attn_backward( + dout_, + q_, + k_, + v_, + out_, + softmax_lse_per_step[i], + dq_per_step[i], + dk_per_step[i], + dv_per_step[i], + cu_seqlens_q, + cu_seqlens_kv * num_kv_chunks, + ctx.max_seqlen_q, + ctx.max_seqlen_kv * num_kv_chunks, + ctx.dropout_p, + ctx.softmax_scale, + True, + window_size=ctx.window_size, + rng_state=rng_states[i], + **fa_optional_backward_kwargs, + ) + + if i > 0: + with torch.cuda.stream(flash_attn_streams[i - 1]): + chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1] + num_kv_chunks = chunk_ids_to_kv_ag.numel() if ctx.qkv_format == "bshd": - dq[:, 1, ...].copy_(dq_) + dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1])) + dk_per_step[i - 1] = ( + dk_per_step[i - 1] + .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:]) + .movedim(0, 2) + .contiguous() + ) + dv_per_step[i - 1] = ( + dv_per_step[i - 1] + .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:]) + .movedim(0, 2) + .contiguous() + ) elif ctx.qkv_format == "sbhd": - dq[1].copy_(dq_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy") - else: - if i == 0: - dq.copy_(dq_) - else: - dq.add_(dq_) - - if attn_dbias is not None: - idx = (rank + i + 1) % cp_size - if i == (cp_size - 1) or not causal: - # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] - dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) - attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) - attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) - elif i >= (cp_size - rank - 1): - # [b, np, sq, sk//(2*cp)] - attn_dbias[..., idx, :].copy_(dbias_) - else: - # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] - dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) - attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) - attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) - - # wait until dKV is received - for req in send_recv_reqs: - req.wait() + dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1])) + dk_per_step[i - 1] = dk_per_step[i - 1].view( + num_kv_chunks, -1, *k.shape[-3:] + ) + dv_per_step[i - 1] = dv_per_step[i - 1].view( + num_kv_chunks, -1, *v.shape[-3:] + ) - dkv = p2p_comm_buffers[(i + 1) % 2][1] - if ctx.use_fused_attention: - dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] - dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) - elif ctx.qkv_format == "sbhd": - # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn] - dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:]) - else: - # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal - # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal - dkv_ = dkv_.view(*dkv.shape) + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1]) + dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) - if causal: - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy") - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none") - else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none") - elif i > 0: - dkv.add_(dkv_) - else: - dkv.copy_(dkv_) - else: - if i == 0: - dkv.copy_(dkv_) - else: - dkv.add_(dkv_) + torch.cuda.current_stream().wait_stream(ctx.cp_stream) - if causal: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - dq = dq.view(-1, *q.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:]) + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - if attn_dbias is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] - attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + if ctx.qkv_format == "bshd": + dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) + dk = dk.transpose(0, 1).contiguous() + dv = dv.transpose(0, 1).contiguous() + elif ctx.qkv_format == "sbhd": + dq = dq.view(-1, *dq.shape[-3:]) return ( None, dq, - dkv[0], - dkv[1], + dk, + dv, + None, None, None, None, @@ -2290,7 +3287,6 @@ def backward(ctx, dout): None, None, None, - attn_dbias, None, None, ) @@ -2302,15 +3298,16 @@ def attn_forward_func_with_cp( k, v, cu_seqlens_q, - cu_seqlens_k, + cu_seqlens_kv, max_seqlen_q, - max_seqlen_k, + max_seqlen_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, cp_group, cp_global_ranks, cp_stream, + cp_comm_type, softmax_scale=None, qkv_format="bshd", attn_mask_type="causal", @@ -2318,8 +3315,14 @@ def attn_forward_func_with_cp( attn_bias=None, deterministic=False, use_fused_attention=False, + window_size=None, + fp8=False, + fp8_meta=None, ) -> torch.Tensor: - """Attention implementation with context parallelism""" + """ + Attention implementation with context parallelism. + """ + assert qkv_format in [ "bshd", "sbhd", @@ -2340,29 +3343,67 @@ def attn_forward_func_with_cp( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" ) - out = AttnFuncWithCP.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_global_ranks, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, + assert ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" + + sliding_window_attn = ( + window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) + + if sliding_window_attn or cp_comm_type == "all_gather": + out = AttnFuncWithCPAndKVAllGather.apply( + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + ) + elif cp_comm_type == "p2p": + out = AttnFuncWithCPAndKVP2P.apply( + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + fp8, + fp8_meta, + ) + else: + raise ValueError(f"Unsupported communication type: {cp_comm_type}!") + return out @@ -2682,6 +3723,7 @@ class UnfusedDotProductAttention(torch.nn.Module): def __init__( self, softmax_scale: float, + attention_type: str = "self", attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, @@ -2689,6 +3731,7 @@ def __init__( super().__init__() self.softmax_scale = softmax_scale + self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number @@ -2728,6 +3771,58 @@ def forward( query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[1], + query_layer.shape[0], + key_layer.shape[0], + ) + if "padding" in attn_mask_type: + if self.attention_type == "self": + assert attention_mask.shape == ( + batch_size, + 1, + 1, + max_seqlen_q, + ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + assert ( + len(attention_mask) == 2 + and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) + and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) + ), ( + "attention_mask should be a tuple of two tensors with shapes " + "[b, 1, 1, sq] and [b, 1, 1, skv]!" + ) + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + mask = attention_mask.squeeze(1).logical_not() + actual_seqlens_q = mask[:, :, 0].sum(dim=1) + actual_seqlens_kv = mask[:, 0, :].sum(dim=1) + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if attn_mask_type == "padding_causal": + attention_mask = torch.logical_or( + torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), + attention_mask, + ) + if attn_mask_type == "padding_causal_bottom_right": + attention_mask = torch.logical_or( + torch.where( + mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + < 0, + 1, + 0, + ), + attention_mask, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -2782,7 +3877,7 @@ def forward( key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=scale, - ) + ).view(*output_size) elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" @@ -2790,10 +3885,7 @@ def forward( query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] ) - matmul_result = ( - matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3]) - + core_attention_bias - ).view(-1, output_size[2], output_size[3]) + matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale elif core_attention_bias_type in ["post_scale_bias", "alibi"]: @@ -2804,6 +3896,8 @@ def forward( output_size[1], output_size[2], output_size[3], + actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, + actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) @@ -2814,26 +3908,21 @@ def forward( beta=0.0, alpha=scale, ) - matmul_result = ( - ( - matmul_result.view( - output_size[0], output_size[1], output_size[2], output_size[3] - ) - + core_attention_bias - ) - .view(-1, output_size[2], output_size[3]) - .to(dtype=query_layer.dtype) + matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to( + dtype=query_layer.dtype ) - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # attention scores and attention mask [b, np, sq, sk] softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( - attention_scores, attention_mask, attn_mask_type, softmax_scale + matmul_result, attention_mask, attn_mask_type, softmax_scale ) + # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q) + # the columns (pad tokens from k) are already zeroed out during softmax + if "padding" in attn_mask_type: + attention_probs = attention_probs.masked_fill(attention_mask, 0) + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with self.attention_dropout_ctx(): @@ -2929,7 +4018,7 @@ def get_qkv_layout( qkv_format: str, default = `sbhd` Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length dimension, `b` batch size, `h` the number of attention heads, - `d` head size, and `t` the total number of sequences in a batch, i.e. + `d` head size, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. Returns @@ -2959,12 +4048,14 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) stride = k.stride() - check_strides_kv = all(stride == x.stride() for x in [k, v]) + check_strides_kv = torch.equal( + torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + ) shape = q.shape check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) shape = k.shape - check_shapes_kv = all(shape == x.shape for x in [k, v]) + check_shapes_kv = shape[:-1] == v.shape[:-1] last_dim_size = q.shape[-1] check_last_dim_offsets_qkv = all( @@ -3044,31 +4135,28 @@ def check_set_window_size( """ orig_window_size = window_size if "causal" in attn_mask_type: - if orig_window_size is None or ( - orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0] - ): + if orig_window_size is None: window_size = (-1, 0) - warnings.warn( - "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type - ) - elif orig_window_size[0] >= 0: + elif orig_window_size == (-1, -1) or ( + orig_window_size[0] >= 0 and orig_window_size[1] != 0 + ): window_size = (orig_window_size[0], 0) warnings.warn( "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type ) - else: + elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0): assert False, ( "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type ) elif attn_mask_type in ["no_mask", "padding", "arbitrary"]: - if orig_window_size is None or ( - orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0] - ): + if orig_window_size is None: + window_size = (-1, -1) + elif orig_window_size == (-1, 0): window_size = (-1, -1) warnings.warn( "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type ) - elif orig_window_size[0] < 0 or orig_window_size[0] < 0: + elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0): assert False, ( "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type ) @@ -3124,6 +4212,7 @@ def forward( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", ) -> torch.Tensor: """flash-attn fprop""" @@ -3139,7 +4228,8 @@ def forward( qkv_layout in QKVLayouts ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" - context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) + cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -3166,6 +4256,8 @@ def forward( if qkv_format in ["sbhd", "bshd"]: max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size if not context_parallel: # [b * s, h, d] query_layer, key_layer, value_layer = [ @@ -3229,10 +4321,6 @@ def forward( max_seqlen_kv = seqlens_kv.max().item() if context_parallel: - assert window_size in ( - (-1, -1), - (-1, 0), - ), "Sliding window attention is not supported with context parallelism." assert ( alibi_slopes is None ), "Alibi slope bias addition is not supported with context parallelism." @@ -3246,16 +4334,18 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - None, - None, + cu_seqlens_q, + cu_seqlens_kv, self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, cp_stream, + cp_comm_type, softmax_scale=self.softmax_scale, qkv_format="bshd" if qkv_format == "sbhd" else qkv_format, attn_mask_type=attn_mask_type, deterministic=self.deterministic, + window_size=window_size, ) else: @@ -3275,6 +4365,8 @@ def forward( fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None output = flash_attn_forward_func( query_layer, key_layer, @@ -3294,10 +4386,12 @@ def forward( if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous() + output = ( + output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous() + ) elif qkv_format == "bshd": # (bs)hd -> bs(hd) - output = output.view(batch_size, max_seqlen_q, -1).contiguous() + output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous() elif qkv_format == "thd": # thd -> t(hd) output = output.view(output.shape[0], -1).contiguous() @@ -3361,9 +4455,7 @@ def forward( fp8_meta, deterministic, ): - logger = logging.getLogger("FusedAttnFunc_qkvpacked") if fp8: - logger.debug("Running forward in FP8") if fp8_meta["recipe"].fp8_mha: assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv @@ -3447,7 +4539,6 @@ def forward( fp8_meta["scaling_fwd"].scale_inv.clone(), ) else: - logger.debug("Running forward in %s", qkv.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( is_training, max_seqlen, @@ -3500,7 +4591,6 @@ def forward( @staticmethod def backward(ctx, d_out): - logger = logging.getLogger("FusedAttnFunc_qkvpacked") if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance( d_out, Float8Tensor @@ -3554,7 +4644,6 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): if ctx.fp8: - logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False @@ -3620,7 +4709,6 @@ def backward(ctx, d_out): ctx.qkv_dtype, ).view(dqkv_fp8.shape) else: - logger.debug("Running backward in %s", qkv.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(qkv.dtype) dqkv, *rest = fused_attn_bwd_qkvpacked( @@ -3738,9 +4826,7 @@ def forward( fp8_meta, deterministic, ): - logger = logging.getLogger("FusedAttnFunc_kvpacked") if fp8: - logger.debug("Running forward in FP8") if fp8_meta["recipe"].fp8_mha: assert isinstance(q, Float8Tensor) and isinstance( kv, Float8Tensor @@ -3837,7 +4923,6 @@ def forward( fp8_meta["scaling_fwd"].scale_inv.clone(), ) else: - logger.debug("Running forward in %s", q.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( is_training, max_seqlen_q, @@ -3901,7 +4986,6 @@ def forward( @staticmethod def backward(ctx, d_out): - logger = logging.getLogger("FusedAttnFunc_kvpacked") if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance( d_out, Float8Tensor @@ -3959,7 +5043,6 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): if ctx.fp8: - logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False @@ -4044,7 +5127,6 @@ def backward(ctx, d_out): ctx.qkv_dtype, ).view(dkv_fp8.shape) else: - logger.debug("Running backward in %s", q.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(q.dtype) dq, dkv, *rest = fused_attn_bwd_kvpacked( @@ -4175,9 +5257,7 @@ def forward( fp8_meta, deterministic, ): - logger = logging.getLogger("FusedAttnFunc") if fp8: - logger.debug("Running forward in FP8") fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if fp8_meta["recipe"].fp8_mha: @@ -4345,7 +5425,6 @@ def forward( fp8_meta["scaling_fwd"].scale_inv.clone(), ) else: - logger.debug("Running forward in %s", q.dtype) out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -4419,7 +5498,6 @@ def forward( @staticmethod def backward(ctx, d_out): - logger = logging.getLogger("FusedAttnFunc") if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance( d_out, Float8Tensor @@ -4481,7 +5559,6 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn"): if ctx.fp8: - logger.debug("Running backward in FP8") fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False @@ -4619,7 +5696,6 @@ def backward(ctx, d_out): ctx.qkv_dtype, ).view(dv_fp8.shape) else: - logger.debug("Running backward in %s", q.dtype) if d_out.dtype == torch.uint8: d_out = d_out_f8tensor.from_float8(q.dtype) dq, dk, dv, *rest = fused_attn_bwd( @@ -4760,7 +5836,6 @@ def __init__( ) -> None: super().__init__() - self.logger = logging.getLogger("FusedAttention") self.softmax_scale = softmax_scale self.attention_dropout = attention_dropout self.attention_dropout_ctx = attention_dropout_ctx @@ -4815,6 +5890,7 @@ def forward( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -4834,7 +5910,8 @@ def forward( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" - context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) + cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -4851,6 +5928,8 @@ def forward( query_layer.shape[1], key_layer.shape[1], ) + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size if "padding" in attn_mask_type: assert not context_parallel, "Padding mask not supported with context parallelism!" @@ -4898,9 +5977,21 @@ def forward( and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) ) + if fp8: + assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( + f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" + " is required for FP8 attention!" + ) + assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" + assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism with" + " FP8!" + ) + if context_parallel: assert ( - fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8 + or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen ), f"{fused_attention_backend} does not work with context parallelism!" assert core_attention_bias_type not in [ "alibi" @@ -4924,23 +6015,20 @@ def forward( cp_group, cp_global_ranks, cp_stream, + cp_comm_type, softmax_scale=self.softmax_scale, qkv_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, attn_bias=core_attention_bias, + deterministic=self.deterministic, use_fused_attention=True, + window_size=window_size, + fp8=fp8, + fp8_meta=fp8_meta, ) else: with self.attention_dropout_ctx(): - if fp8: - assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( - f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" - " is required for FP8 attention!" - ) - assert ( - fp8_meta is not None - ), "FP8 metadata fp8_meta is required for FP8 attention!" output = FusedAttnFunc.apply( self.training, max_seqlen_q, @@ -4994,8 +6082,9 @@ class DotProductAttention(TransformerEngineBaseModule): ---------- num_attention_heads : int number of attention heads in the transformer layer. - kv_channels : int - number of key-query-value channels per attention head. + kv_channels : Union[int, Tuple[int, int]] + the head size in key and value tensors. If the same, :attr:`kv_channels` can be + an integer; if not, :attr:`kv_channels` should be a tuple of two integers. num_gqa_groups : Optional[int] = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -5048,7 +6137,7 @@ class DotProductAttention(TransformerEngineBaseModule): qkv_format: str, default = `sbhd` dimension format for `query_layer`, `key_layer` and `value_layer`, {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size, - `h` the number of heads, `d` head size, and `t` the total number of sequences + `h` the number of heads, `d` head size, and `t` the total number of tokens in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats are used for when sequences in a batch are of equal length or padded to equal length, and the `thd` format is used for when sequences in a batch @@ -5057,7 +6146,7 @@ class DotProductAttention(TransformerEngineBaseModule): For that, please use `get_qkv_layout` to gain the layout information. softmax_scale: Optional[float], default = `None` softmax scale for the attention scores. If `None`, defaults to - `1.0 / math.sqrt(kv_channels)`. + `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. Parallelism parameters ---------------------- @@ -5076,12 +6165,15 @@ class DotProductAttention(TransformerEngineBaseModule): compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ def __init__( self, num_attention_heads: int, - kv_channels: int, + kv_channels: Union[int, Tuple[int, int]], num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, qkv_format: str = "sbhd", @@ -5096,11 +6188,15 @@ def __init__( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, ) -> None: super().__init__() self.logger = logging.getLogger("DotProductAttention") + self.logger.setLevel(_log_level) + if not self.logger.hasHandlers(): + self.logger.addHandler(_stream_handler) self.qkv_format = qkv_format attn_mask_type = attn_mask_type.replace(",", "_") if attn_mask_type == "causal_padding": @@ -5120,11 +6216,17 @@ def __init__( self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + self.cp_comm_type = cp_comm_type - self.hidden_size_per_attention_head = kv_channels + self.hidden_size_per_attention_head_k = ( + kv_channels if isinstance(kv_channels, int) else kv_channels[0] + ) + self.hidden_size_per_attention_head_v = ( + kv_channels if isinstance(kv_channels, int) else kv_channels[1] + ) self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) + self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) assert ( num_attention_heads % self.num_gqa_groups == 0 @@ -5139,7 +6241,9 @@ def __init__( attention_dropout_ctx = self.rng_states_tracker.fork if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(kv_channels) + softmax_scale = 1.0 / math.sqrt( + kv_channels if isinstance(kv_channels, int) else kv_channels[0] + ) self.deterministic = ( not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) @@ -5194,7 +6298,10 @@ def __init__( ) self.unfused_attention = UnfusedDotProductAttention( - softmax_scale, **attn_kwargs, layer_number=layer_number + softmax_scale, + attention_type=attention_type, + **attn_kwargs, + layer_number=layer_number, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -5236,6 +6343,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -5249,10 +6357,14 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + self.cp_comm_type = cp_comm_type @no_torch_dynamo(recursive=False) def forward( @@ -5286,16 +6398,6 @@ def forward( Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`. - .. note:: - - Input tensor :attr:`query_layer` must be of shape - (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`, - :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer` - must each be of shape (:attr:`sequence_length`, :attr:`batch_size`, - :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape - (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads` - * :attr:`kv_channels`) is returned. - .. note:: DotProductAttention supports three backends: 1) FlashAttention which calls @@ -5423,7 +6525,7 @@ def forward( if self.fp8_meta["recipe"].fp8_mha: if not self.fp8_meta["recipe"].fp8_dpa: self.fp8_meta["recipe"].fp8_dpa = True - self.logger.WARNING( + self.logger.warning( """Forcing fp8_meta["recipe"].fp8_dpa=True due to """ """fp8_meta["recipe"].fp8_mha=True""" ) @@ -5445,7 +6547,17 @@ def forward( assert ( query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype ), "Queries, keys and values must have the same data type!" - assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" + assert ( + key_layer.shape[:-1] == value_layer.shape[:-1] + ), "Keys and values must have the same batch size, sequence length and number of heads!" + assert ( + key_layer.shape[-1] == self.hidden_size_per_attention_head_k + ), f"Keys have head_dim = {key_layer.shape[-1]}, " + "but expected head_dim = {self.hidden_size_per_attention_head_k}!" + assert ( + value_layer.shape[-1] == self.hidden_size_per_attention_head_v + ), f"Values have head_dim = {value_layer.shape[-1]}, " + "but expected head_dim = {self.hidden_size_per_attention_head_v}!" if attn_mask_type is None: attn_mask_type = self.attn_mask_type @@ -5479,6 +6591,11 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) @@ -5539,13 +6656,22 @@ def forward( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" if max_seqlen_q is None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) + if cu_seqlens_q_padded is not None: + seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + else: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64) if max_seqlen_kv is None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) + if cu_seqlens_kv_padded is not None: + seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] + else: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) batch_size = len(cu_seqlens_q) - 1 + cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group) + context_parallel = cp_size > 1 + if qkv_format in ["sbhd", "bshd"]: assert all( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) @@ -5556,6 +6682,8 @@ def forward( if qkv_format == "bshd": max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) batch_size = query_layer.shape[0] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size if cu_seqlens_q is not None: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] assert all( @@ -5573,7 +6701,7 @@ def forward( assert ( attention_mask is not None ), "Please provide attention_mask for padding!" - if max_seqlen_q == max_seqlen_kv: + if self.attention_type == "self": cu_seqlens_q = get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: @@ -5627,10 +6755,6 @@ def forward( _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - context_parallel = ( - self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1 - ) - core_attention_bias_shape = None if core_attention_bias is not None: if ( @@ -5671,7 +6795,8 @@ def forward( num_gqa_groups=key_layer.shape[-2], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - head_dim=query_layer.shape[-1], + head_dim_qk=query_layer.shape[-1], + head_dim_v=value_layer.shape[-1], attn_mask_type=attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, @@ -5740,6 +6865,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) @@ -5782,6 +6908,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, ) @@ -5806,6 +6933,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, ) @@ -6234,6 +7362,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -6247,13 +7376,16 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): if index == 0: continue if hasattr(child, "set_context_parallel_group"): - child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream) + child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type) def forward( self, @@ -6269,6 +7401,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, ) -> Tuple[Union[torch.Tensor, None], ...]: """ @@ -6334,6 +7470,18 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + max_seqlen_q: Optional[int], default = `None` + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q` if not provided. + max_seqlen_kv: Optional[int], default = `None` + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. """ @@ -6360,6 +7508,9 @@ def forward( # ================================================= if inference_params and self.layer_number is not None: + assert ( + self.qkv_format != "thd" + ), "qkv_format == thd is not supported for an inference with KV-cache!" if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_length inf_max_batch_size = inference_params.max_batch_size @@ -6442,13 +7593,18 @@ def forward( dim=split_dim, ) - # query: -> [sq, b, np, hn] - # key, value: -> [sq, b, ng, hn] - query_layer, key_layer, value_layer = ( - x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) - for x in (query_layer, key_layer, value_layer) - ) - + if self.qkv_format == "thd": + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) + else: + # query: -> [sq, b, np, hn] + # key, value: -> [sq, b, ng, hn] + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) elif self.attention_type == "cross": # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( @@ -6562,8 +7718,10 @@ def forward( key_layer, value_layer, qkv_format=self.qkv_format, - cu_seqlens_q=None, - cu_seqlens_kv=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 4dc169da00..d0ba644621 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -140,7 +140,7 @@ def fused_attn_fwd_qkvpacked( output tensor, amax of O, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -342,7 +342,7 @@ def fused_attn_bwd_qkvpacked( output tensor, amax of dQKV, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -508,7 +508,7 @@ def fused_attn_fwd_kvpacked( output tensor, amax of O, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -729,7 +729,7 @@ def fused_attn_bwd_kvpacked( output tensor, amax of dQKV, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -907,7 +907,7 @@ def fused_attn_fwd( output tensor, amax of O, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -1135,7 +1135,7 @@ def fused_attn_bwd( output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim) as the default + if None, use 1.0/sqrt(head_dim_qk) as the default dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 38392a5795..8502f70491 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for GEMM extensions""" +import functools from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex @@ -13,6 +14,12 @@ __all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +@functools.lru_cache(maxsize=None) +def _empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor() + + def fp8_gemm( A: torch.Tensor, A_scale_inv: torch.Tensor, @@ -39,7 +46,7 @@ def fp8_gemm( ) -> torch.Tensor: """TN layout GEMM with fp8 inputs.""" - empty_tensor = torch.Tensor() + empty_tensor = _empty_tensor() if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: assert fp8_meta_tensor is not None and out_index is not None assert_dim_for_fp8_exec(A) @@ -195,7 +202,7 @@ def gemm( assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" - empty_tensor = torch.Tensor() + empty_tensor = _empty_tensor() fp8_index = -1 # dummy index if out is None: @@ -313,8 +320,8 @@ def grouped_gemm( transa = layout[0] == "T" transb = layout[1] == "T" num_gemms = len(A) - empty_tensor = torch.Tensor() - empty_tensors = [torch.Tensor()] * num_gemms + empty_tensor = _empty_tensor() + empty_tensors = [empty_tensor] * num_gemms if gelu and not grad: gelu_input = [ @@ -401,8 +408,8 @@ def fp8_grouped_gemm( """ num_gemms = len(A) - empty_tensor = torch.Tensor() - empty_tensors = [torch.Tensor()] * num_gemms + empty_tensor = _empty_tensor() + empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: assert fp8_meta_tensor is not None and out_offset is not None for a, b in zip(A, B): diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index de83bcd7f5..d96b743b9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """Python interface for transpose extensions""" -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import transformer_engine_torch as tex from ..constants import TE_DType @@ -13,6 +13,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -118,3 +119,25 @@ def fp8_cast_transpose_bgrad_dgelu_fused( amax_offset=int(fp8_tensor), scale_inv_offset=int(fp8_tensor), ) + + +def fp8_multi_cast_transpose_fused( + input_list: List[torch.Tensor], + fp8_meta_tensor: tex.FP8TensorMeta, + scale_indices: List[int], + amax_indices: List[int], + scale_inv_indices: List[int], + otype: tex.DType, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Cast + Transpose with FP8 output""" + + return tex.fused_multi_cast_transpose_alloc( + input_list, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + scale_indices, + amax_indices, + scale_inv_indices, + otype, + ) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index b07c6d3508..4e9c74d396 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -274,7 +274,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): def __init__( self, num_offload_group, # must be <= actual number of groups (number of commits) - num_prefetch_group=1, + num_model_group, tensor_need_offloading_checker=(lambda t: True), debug=False, ) -> None: @@ -283,53 +283,29 @@ def __init__( tensor_need_offloading_checker=tensor_need_offloading_checker, debug=debug, ) - self.num_prefetch_group = num_prefetch_group - - # prepare for tensor buffer - self.tensor_id_to_tensor_buf_double_bufs = [] - for _ in range(2): - self.tensor_id_to_tensor_buf_double_bufs.append({}) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant # allocate streams and events for synchronization self.d2h_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream() - self.h2d_finish_events = [] - self.compute_stream_bwd_start_events = [] - for _ in range(self.num_offload_group): - self.h2d_finish_events.append(torch.cuda.Event()) - self.compute_stream_bwd_start_events.append(torch.cuda.Event()) - self.d2h_final_event = torch.cuda.Event() - - def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag): - """Get tensor buffer for offloaded tensor.""" - group_id, tensor_id = tensor_tag - # obtain ping-pong buffer - id_buf_map = self.tensor_id_to_tensor_buf_double_bufs[(group_id % 2)] - - if not tensor_id in id_buf_map: - allocate_new_buf = True - else: - tensor_buf = id_buf_map[tensor_id] - allocate_new_buf = ( - tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype - ) - - if allocate_new_buf: - # supposed to only execute once - fp8_offload = isinstance(tensor, Float8Tensor) - buffer = torch.empty( - tensor.size(), - dtype=torch.uint8 if fp8_offload else tensor.dtype, - layout=tensor.layout, - device=tensor.device, - ) - - if isinstance(tensor, Float8Tensor): - id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer) - else: - id_buf_map[tensor_id] = buffer - - return id_buf_map[tensor_id] def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: @@ -347,21 +323,12 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: self.tensor_count_current_group += 1 assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( tensor ): - # first copy the tensor to tensorbuf, - # so that the original tensor will not be deleted - tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) - tensor_buf.copy_(tensor) - if hasattr(tensor, "weight_offloading"): - tensor_buf.weight_offloading = True - if hasattr(tensor, "activation_offloading"): - tensor_buf.activation_offloading = True - # Here we just save it, and at commit, bulk_offload_group will handle it - self.tensor_tag_to_state[tensor_tag] = tensor_buf - else: - self.tensor_tag_to_state[tensor_tag] = tensor + self.tensor_tag_to_buf[tensor_tag] = tensor else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -373,6 +340,7 @@ def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() # which invokes bulk_reload_group. assert not isinstance(tensor, tuple) @@ -389,50 +357,49 @@ def bulk_offload_group(self, group_to_offload): # if offload, return the reference to cpu copy if self.tensor_need_offloading_checker(tensor_on_device): - if hasattr(tensor_on_device, "weight_offloading"): - delattr(tensor_on_device, "weight_offloading") - if hasattr(tensor_on_device, "activation_offloading"): - delattr(tensor_on_device, "activation_offloading") state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) self.tensor_tag_to_state[tensor_tag] = state def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" - # the host should wait for the copying of previous group - # to avoid overwriting buffer - previous_group = current_group - 1 - if previous_group < self.num_offload_group: - torch.cuda.synchronize() - # TODO (guyueh): this part is originally designed to reduce the peak memory usage. # pylint: disable=fixme - # however, uncommenting this part will cause illegal access, have not figured out why. - - if previous_group + 2 >= self.num_offload_group: - # this buffer is no longer required - self.tensor_id_to_tensor_buf_double_bufs[(previous_group % 2)] = {} - - # the copying of this group should wait for the computation stream event - if current_group < self.num_offload_group: - # perform bulk offloading + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload_group(current_group) - if current_group == self.num_offload_group - 1: - self.d2h_stream.record_event(self.d2h_final_event) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 def on_group_commit_forward(self): """This function will cause host device synchronization""" # handle synchronization events self.synchronize_on_group_commit_forward(self.current_group) - # during forward, the next_group_to_fetch always points to the min of - # the last commited group, and the last offloaded group - self.next_group_to_fetch = min(self.current_group, self.num_offload_group - 1) - super().on_group_commit_forward() def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" assert group_to_reload < self.num_offload_group - if group_to_reload == self.num_offload_group - 1: - self.h2d_stream.wait_event(self.d2h_final_event) + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): @@ -449,39 +416,29 @@ def on_group_commit_backward(self): self.current_group -= 1 assert self.current_group >= 0 - # decide the range of group to prefetch - should_prefetch_until_group = self.current_group - self.num_prefetch_group - should_prefetch_until_group = max(should_prefetch_until_group, 0) - - # do prefetch - for group_num_to_prefetch in range( - self.next_group_to_fetch, should_prefetch_until_group - 1, -1 - ): - # record the event in the compute stream, for h2d to wait - torch.cuda.current_stream().record_event( - self.compute_stream_bwd_start_events[group_num_to_prefetch] - ) - - # start of h2d should wait for the compute and the d2h - self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch]) + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: - # recover tensors (copy back from host) - self.bulk_reload_group(group_num_to_prefetch) + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) - # record an event for the backward of this layer to wait - self.h2d_stream.record_event(self.h2d_finish_events[group_num_to_prefetch]) + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) - # always is set to -1 at the end of the backward - self.next_group_to_fetch = min(self.num_offload_group - 1, should_prefetch_until_group - 1) + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 - # wait for the current group - if self.current_group < self.num_offload_group: - torch.cuda.current_stream().wait_event(self.h2d_finish_events[self.current_group]) + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 def get_cpu_offload_context( enabled: bool = False, num_layers: int = 1, + model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = True, ): @@ -506,6 +463,8 @@ def get_cpu_offload_context( num_layers: int, default = 1 Determines the number of transformer layers you want to offload activations/weights for. + model_layers: int, default = 1 + Number of layers in the model that will be used under this context. offload_activations: bool, default = `True` When set to `True`, offloads the activations for the TE layer. offload_weights: bool, default = `True` @@ -537,7 +496,7 @@ def tensor_need_offloading_checker_all(tensor): cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( num_offload_group=num_layers, - num_prefetch_group=1, + num_model_group=model_layers, tensor_need_offloading_checker=tensor_need_offloading_checker, ) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 88609b6ddb..3b4e126943 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Initialize userbuf communicator if (!comm_created) { if (myrank == 0) { - printf("!!! [UB] Create UbufCommOverlap Communicator\n"); + printf("!!! [UB] Create Userbuffers Communicator\n"); } #ifdef NVTE_UB_WITH_MPI create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); @@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { // Allocate and register extra userbuffers int ubuf_bytes = sample.numel() * sample.element_size(); - if (transformer_engine::getenv("UB_SKIPMC")) { - _ubuf = torch::zeros_like(sample); - _ubuf_ptr = _ubuf.data_ptr(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, false); - } else { - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - } + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); if (_ub_comm->myrank == 0) { printf("!!! [UB] Register UBuf %d\n", _ub_reg); @@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); + _ub_comm->sms = ori_sms; return {D, output_tensor}; } // bulk_overlap @@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); int *counter_ptr = reinterpret_cast(counter.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - int ori_sms = _ub_comm->sms; // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); @@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_strided_atomic_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); @@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); @@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } break; } else { + assert(_ubuf.element_size() != 1); consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - // if (i == _num_splits-1) { - // _ub_comm->sms = UB_MAX_SM; - // } reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); } @@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { // Get GEMM dimensions + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - int ori_sms = _ub_comm->sms; // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); @@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); @@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); @@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reducescatter2_userbuff_stridedoutput_fp8( rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); @@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Initialize userbuf communicator if (!comm_created) { if (myrank == 0) { - printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n"); + printf("!!! [UB] Create Userbuffers Communicator\n"); } #ifdef NVTE_UB_WITH_MPI create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); @@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); num_ubuf_chunks = static_cast(tp_size * 2 - 1); } - if (transformer_engine::getenv("UB_SKIPMC")) { - _ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, - sample.options()); - _ubuf_ptr = _ubuf.data_ptr(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, false); - } else { - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = - torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, - sample.options()); - } + + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob( + _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); if (_ub_comm->myrank == 0) { printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); } @@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); for (int i = 0; i < num_ubuf_chunks; i++) { - torch::Tensor ubuf_chunk = torch::from_blob( - ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); - _ubufs.push_back(ubuf_chunk); + auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, + sample.options()); + _ubufs.push_back(std::move(ubuf_chunk)); ubuf_byte_ptr += ubuf_chunk_bytes; } @@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); if (_rank == 0 && env_p != nullptr) { if (env_p[0] == '1') { + _use_ce = 0; + _ub_comm->push = 1; printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); } } @@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); if (env_p != nullptr && env_p[0] == '1') { if (i == 0) { + _ub_comm->use_ce = 0; userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, true, (cudaStream_t)_stream_recv); @@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); // Return the last N rows of D_buffer + _ub_comm->sms = ori_sms; torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); return D_return; } // atomic_gemm_overlap_ag @@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); at::cuda::setCurrentCUDAStream(stream_main); + _ub_comm->sms = ori_sms; return D; } // split_overlap_ag @@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);); } else { @@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } + _ub_comm->sms = ori_sms; } /* @@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; @@ -1210,11 +1205,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}, workspace.options()); - if (i == _tp_size - 1) { - at::cuda::setCurrentCUDAStream(stream_main); - } else { - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - } + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); @@ -1235,6 +1226,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { recv_rank, (cudaStream_t)_stream_recv); } } + at::cuda::setCurrentCUDAStream(stream_main); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); @@ -1245,7 +1243,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - B_type, fp8_type, + D_type, fp8_type, reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);); } else { @@ -1253,12 +1251,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + _ub_comm->sms = ori_sms; } /* diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f06b0cb197..05e4e97112 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -14,11 +14,14 @@ * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right); +NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); std::vector fused_attn_fwd_qkvpacked( size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, @@ -177,6 +180,11 @@ void fused_multi_cast_transpose(std::vector input_list, std::vector scale_inv_output_list, transformer_engine::DType otype); +std::tuple, std::vector> fused_multi_cast_transpose_alloc( + std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + std::vector scale_indices, std::vector amax_indices, + std::vector scale_inv_indices, transformer_engine::DType otype); + at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); @@ -415,12 +423,19 @@ std::tuple multi_tensor_unscale_l2norm_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor inv_scale, at::optional per_tensor_python); +using transformer_engine::DType; void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay); +void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype); + void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index af84054b4c..50eb7b830f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -14,11 +14,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) { + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, window_size_left, window_size_right); + head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } @@ -127,6 +128,9 @@ std::vector fused_attn_fwd_qkvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + O.fill_(0); + } // BF16 or FP16 te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); @@ -288,6 +292,9 @@ std::vector fused_attn_bwd_qkvpacked( amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQKV.fill_(0); + } // BF16 or FP16 te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); @@ -328,6 +335,9 @@ std::vector fused_attn_bwd_qkvpacked( options); te_dBias = makeTransformerEngineTensor(dBias); } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } } // create cu_seqlens tensorwrappers @@ -427,6 +437,9 @@ std::vector fused_attn_fwd_kvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + O.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_KV = @@ -614,6 +627,10 @@ std::vector fused_attn_bwd_kvpacked( amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQ.fill_(0); + dKV.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_KV = @@ -684,6 +701,9 @@ std::vector fused_attn_bwd_kvpacked( options); te_dBias = makeTransformerEngineTensor(dBias); } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } } // create workspace @@ -742,7 +762,11 @@ std::vector fused_attn_fwd( std::vector v_shape{v_sizes.begin(), v_sizes.end()}; // create output tensor O - auto O = torch::empty_like(Q); + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto o_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1]; + std::vector o_shape_tmp{o_shape.begin(), o_shape.end()}; + auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options); // construct NVTE tensors TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; @@ -771,15 +795,18 @@ std::vector fused_attn_fwd( descale_QKV.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + O.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -817,8 +844,7 @@ std::vector fused_attn_fwd( auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); @@ -913,8 +939,11 @@ std::vector fused_attn_bwd( std::vector v_shape{v_sizes.begin(), v_sizes.end()}; auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; + auto d_qk = q_shape[q_shape.size() - 1]; + auto d_v = v_shape[v_shape.size() - 1]; auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + std::vector o_shape{q_sizes.begin(), q_sizes.end()}; + o_shape[o_shape.size() - 1] = d_v; at::Tensor dQ; at::Tensor dK; @@ -993,7 +1022,7 @@ std::vector fused_attn_bwd( TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && + if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -1019,9 +1048,9 @@ std::vector fused_attn_bwd( descale_QKV.value().data_ptr()); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, + te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); @@ -1037,13 +1066,18 @@ std::vector fused_attn_bwd( makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } // BF16 or FP16 te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dQ = @@ -1109,6 +1143,9 @@ std::vector fused_attn_bwd( options); te_dBias = makeTransformerEngineTensor(dBias); } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } } // create workspace @@ -1535,7 +1572,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float dtype *p_per_step = reinterpret_cast(&data_per_step); dtype *p = reinterpret_cast(&data); for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { - p[k] += p_per_step[k] * lse_corrected_exp; + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); } reinterpret_cast(cur_out)[j] = data; } diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index 2752f92348..09b53a8976 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -8,16 +8,19 @@ #include #include #include +#include // Another possibility: // #include #include +#include "common/utils.cuh" #include "multi_tensor_apply.cuh" #include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 +#define THREADS_PER_WARP 32 typedef enum { ADAM_MODE_0 = 0, // L2 regularization mode @@ -25,6 +28,156 @@ typedef enum { } adamMode_t; using MATH_T = float; +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +using transformer_engine::DType; + +template +struct is_fp8 : std::false_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template +struct FP8Data { + float scale; + float *amax_ptr; + float *scale_inv_ptr; + float max; + int warp_id; +}; + +template <> +struct FP8Data {}; + +template +struct AdamFunctorMaster { + static constexpr bool is_fp8_type = is_fp8::value; + + __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, + TensorListMetadata<5, is_fp8_type> &tl, // NOLINT(*) + const float beta1, const float beta2, + const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + FP8Data fp8_data; + + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); + p_master += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + if constexpr (is_fp8_type) { + float *scale_ptr = reinterpret_cast(tl.fp8_meta_addresses[0][tensor_loc]); + fp8_data.scale = scale_ptr != nullptr ? *scale_ptr : 1; + fp8_data.amax_ptr = reinterpret_cast(tl.fp8_meta_addresses[1][tensor_loc]); + fp8_data.scale_inv_ptr = reinterpret_cast(tl.fp8_meta_addresses[2][tensor_loc]); + fp8_data.warp_id = threadIdx.x / THREADS_PER_WARP; + fp8_data.max = 0; + } + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]); + r_p[ii] = static_cast(p_master[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p_master[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + if constexpr (is_fp8_type) { + __builtin_assume(fp8_data.max >= 0); + fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max); + p[i] = static_cast(r_p[ii] * fp8_data.scale); + } else { + p[i] = static_cast(r_p[ii]); + } + } + } + } + + if constexpr (is_fp8_type) { + fp8_data.max = transformer_engine::reduce_max( + fp8_data.max, fp8_data.warp_id); + if (threadIdx.x == 0) { + if (fp8_data.amax_ptr != nullptr) { + transformer_engine::atomicMaxFloat(fp8_data.amax_ptr, fp8_data.max); + } + if (fp8_data.scale_inv_ptr != nullptr) { + *fp8_data.scale_inv_ptr = __frcp_rn(fp8_data.scale); + } + } + } + } +}; template struct AdamFunctor { @@ -338,22 +491,114 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, } } + const auto p_in_type = tensor_lists[1][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 4: g, p, m, v + // case 5: g, p, m, v, p_master + TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5"); + + if (requires_64bit_indexing) { + if (tl_size == 4) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, + weight_decay);) + } else { + // g, p, m, v, p_master + const auto g_in_type = tensor_lists[0][0].scalar_type(); + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } + } else { + if (tl_size == 4) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, + weight_decay);) + } else { + const auto g_in_type = tensor_lists[0][0].scalar_type(); + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { + break; + } + } + + const auto g_in_type = tensor_lists[0][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv + TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors"); + if (requires_64bit_indexing) { - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( - tensor_lists[0][0].scalar_type(), 0, "adam", - multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + fp8_dtype, FP8_T, + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 0, "adam", + multi_tensor_apply<5, true>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( - tensor_lists[0][0].scalar_type(), 0, "adam", - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + fp8_dtype, FP8_T, + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 0, "adam", + multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, + lr, (adamMode_t)mode, weight_decay);)); } AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 89bce77ded..11b47ccdec 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -84,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); + m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, + "Fused Multi-tensor Cast + Transpose with allocating output tensors", + py::call_guard()); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard()); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", py::call_guard()); @@ -188,6 +191,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_adam", &multi_tensor_adam_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda, + "Compute and apply gradient update to parameters for Adam optimizer", + py::call_guard()); m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " "support and LR scheduling", diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 473954d099..56f6b56769 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -75,7 +75,7 @@ std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::T // Return immediately if tensors are empty if (M == 0 || N == 0) { - return {grad_bias, grad_output_cast, grad_output_transpose}; + return {grad_bias.zero_(), grad_output_cast, grad_output_transpose}; } // Get pointers for FP8 scale, amax, scale-inverse @@ -196,22 +196,21 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_list, - std::vector scale_inv_list, - transformer_engine::DType otype) { +void fused_multi_cast_transpose_base(std::vector input_list, + std::vector scale_dptr_list, + std::vector cast_output_list, + std::vector transposed_output_list, + std::vector amax_dptr_list, + std::vector scale_inv_dptr_list, + transformer_engine::DType otype) { using namespace transformer_engine; // Extract properties from PyTorch tensors - std::vector input_dptr_list, scale_dptr_list, cast_output_dptr_list, - transposed_output_dptr_list, amax_dptr_list, scale_inv_dptr_list; - std::vector> input_shape_list, scale_shape_list, cast_output_shape_list, - transposed_output_shape_list, amax_shape_list, scale_inv_shape_list; - std::vector input_type_list, scale_type_list, cast_output_type_list, - transposed_output_type_list, amax_type_list, scale_inv_type_list; + std::vector input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list; + std::vector> input_shape_list, cast_output_shape_list, + transposed_output_shape_list; + std::vector input_type_list, cast_output_type_list, + transposed_output_type_list; auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector& dptr_list, std::vector>& shape_list) { dptr_list.push_back(tensor.data_ptr()); @@ -232,20 +231,14 @@ void fused_multi_cast_transpose(std::vector input_list, }; for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list); - extract_tensor_props(scale_list[tensor_id], scale_dptr_list, scale_shape_list, scale_type_list); extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list, cast_output_shape_list); cast_output_type_list.push_back(otype); extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list, transposed_output_shape_list); transposed_output_type_list.push_back(otype); - extract_tensor_props(amax_list[tensor_id], amax_dptr_list, amax_shape_list, amax_type_list); - extract_tensor_props(scale_inv_list[tensor_id], scale_inv_dptr_list, scale_inv_shape_list, - scale_inv_type_list); } - transformer_engine::TensorWrapper workspace; - // Construct TE tensors std::vector nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list; std::vector tensor_wrappers; @@ -257,6 +250,7 @@ void fused_multi_cast_transpose(std::vector input_list, return tensor_wrappers.back().data(); }; for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr) continue; nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i], nullptr, nullptr, nullptr)); nvte_cast_output_list.emplace_back( @@ -280,6 +274,55 @@ void fused_multi_cast_transpose(std::vector input_list, at::cuda::getCurrentCUDAStream()); } +void fused_multi_cast_transpose(std::vector input_list, + std::vector scale_list, + std::vector cast_output_list, + std::vector transposed_output_list, + std::vector amax_list, + std::vector scale_inv_list, + transformer_engine::DType otype) { + std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; + for (size_t i = 0; i < scale_list.size(); ++i) { + scale_dptr_list.push_back(scale_list[i].data_ptr()); + amax_dptr_list.push_back(amax_list[i].data_ptr()); + scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr()); + } + + fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, + transposed_output_list, amax_dptr_list, scale_inv_dptr_list, + otype); +} + +std::tuple, std::vector> fused_multi_cast_transpose_alloc( + std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, + std::vector scale_indices, std::vector amax_indices, + std::vector scale_inv_indices, transformer_engine::DType otype) { + using namespace transformer_engine; + + std::vector cast_output_list; + std::vector transposed_output_list; + std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; + for (size_t i = 0; i < input_list.size(); ++i) { + auto input_i = input_list[i]; + // construct cast output tensors + auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte); + cast_output_list.push_back(cast_output_i); + // construct transposed output tensors + auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte); + transposed_output_list.push_back(transposed_output_i); + // construct amax/scale/scale_inv dptr lists + amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i])); + scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i])); + scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i])); + } + + fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, + transposed_output_list, amax_dptr_list, scale_inv_dptr_list, + otype); + + return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list)); +} + at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) { using namespace transformer_engine; diff --git a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh index 4996dfd05e..e85ec3afc2 100644 --- a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh +++ b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh @@ -12,38 +12,55 @@ #include #include +#include "common/common.h" + // This header is the one-stop shop for all your multi-tensor apply needs. // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; -template -struct TensorListMetadata { +template +struct TensorListMetadataBase { void *addresses[n][depth_to_max_tensors[n - 1]]; int sizes[depth_to_max_tensors[n - 1]]; unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int block_to_chunk[depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; +template +struct TensorListMetadata : public TensorListMetadataBase {}; + +template +struct TensorListMetadata : public TensorListMetadataBase { + void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]]; +}; + template __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl, U callable, ArgTypes... args) { - // Hand the chunk information to the user-supplied functor to process however it likes. + // Hand the chunk information to the user-supplied functor to process however + // it likes. callable(chunk_size, noop_flag, tl, args...); } -template +template void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag, const std::vector> &tensor_lists, T callable, ArgTypes... args) { - TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + if constexpr (USE_FP8) { + TORCH_CHECK(tensor_lists.size() == depth + 3, + "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, " + "amax, scale_inv) for fp8"); + } else { + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + } int len0 = tensor_lists[0].size(); TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); auto ref_device = tensor_lists[0][0].device(); TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); - for (int l = 0; l < tensor_lists.size(); l++) { // No range-based for because I need indices + for (int l = 0; l < depth; l++) { // No range-based for because I need indices TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); for (int t = 0; t < tensor_lists[l].size(); t++) { // TODO: Print which tensor fails. @@ -58,9 +75,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor } } + if constexpr (USE_FP8) { + TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0, + "Size mismatch among tensor lists"); + } + int ntensors = tensor_lists[0].size(); - TensorListMetadata tl; + TensorListMetadata tl; const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); auto stream = at::cuda::getCurrentCUDAStream(); @@ -72,12 +94,15 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); for (int d = 0; d < depth; d++) tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + if constexpr (USE_FP8) { + for (int i = 0; i < 3; i++) + tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr(); + } loc_tensor_info++; auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { - // std::cout << chunks_this_tensor << std::endl; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_chunk[loc_block_info] = chunk; loc_block_info++; @@ -87,7 +112,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); if (tensors_full || blocks_full || last_chunk) { - // using accscalar_t = acc_type; multi_tensor_apply_kernel<<>>( chunk_size, noop_flag.data_ptr(), tl, callable, args...); @@ -100,7 +124,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor tl.start_tensor_this_launch = t + 1; } else { tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; - for (int d = 0; d < depth; d++) tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + } + if constexpr (USE_FP8) { + for (int i = 0; i < 3; i++) { + tl.fp8_meta_addresses[i][0] = tl.fp8_meta_addresses[i][loc_tensor_info - 1]; + } + } loc_tensor_info = 1; tl.start_tensor_this_launch = t; } diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index e1bcfecc13..8515092ae0 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -258,7 +258,8 @@ at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_te // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - const int sm_count = transformer_engine::cuda::sm_count(); + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -293,7 +294,8 @@ std::vector te_grouped_gemm_ts( // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - const int sm_count = transformer_engine::cuda::sm_count(); + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 03a1a6a3df..0cd2a0253b 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) } +template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + +template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index a6f62ac457..e2642bc360 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,11 +3,14 @@ # See LICENSE for license information. """Functions for CUDA Graphs support in FP8""" +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union + import torch from torch.utils._pytree import tree_flatten as _tree_flatten from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle +from transformer_engine.common.recipe import DelayedScaling from .fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -22,6 +25,9 @@ _IS_GRAPH_CAPTURING = False +_T = TypeVar("_T") +SingleOrTuple = Union[_T, Tuple[_T, ...]] + def set_capture_start() -> None: """Record beginning of `make_graphed_callables`.""" @@ -48,13 +54,14 @@ def graph_pool_handle(): def _make_graphed_callables( - callables, - sample_args, - num_warmup_iters=3, - allow_unused_input=False, - fp8_weight_caching=False, - _order=None, -): + callables: SingleOrTuple[Callable], + sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + fp8_weight_caching: bool = False, + sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, + _order: Optional[List[int]] = None, +) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` """ @@ -65,16 +72,38 @@ def _make_graphed_callables( "caching. Please set `cache_enabled=False`." ) - just_one_callable = False + # Default is to pass no kwargs to callables + if sample_kwargs is None: + if isinstance(callables, tuple): + sample_kwargs = tuple({} for _ in range(len(sample_args))) + else: + sample_kwargs = {} + # Canonicalize args as tuples + just_one_callable = False if not isinstance(callables, tuple): just_one_callable = True callables = (callables,) sample_args = (sample_args,) + sample_kwargs = (sample_kwargs,) - flatten_sample_args = [] - if _order is not None: - # order is a list containing 1..model_chunk values in the order of microbatch schedule + # Check sizes of args + if _order is None: + assert len(sample_args) == len(callables) + assert len(sample_kwargs) == len(callables) + else: + # Custom logic for interleaved pipeline parallelism + # Note: This is tightly coupled with the Megatron-core + # implementation of interleaved pipeline parallelism at + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py. + # Note: The model is assumed to consist of layers + # (corresponding to callables) that are grouped into + # equally-sized model chunks. _order is a list of chunk + # indices (1-indexed) that indicates the order in which the + # layers are evaluated. Positive values indicate forward + # passes and negative values indicate backward passes. Each + # entry in sample_args corresponds to one of the forward + # passes. num_model_chunks = max(_order) num_microbatches = len(_order) // num_model_chunks // 2 assert num_model_chunks * num_microbatches * 2 == len(_order) @@ -90,10 +119,13 @@ def _make_graphed_callables( f"Expected {num_model_chunks * num_microbatches}" + f"args tuple, but got {len(sample_args)}." ) + assert len(sample_kwargs) == len(sample_args) if fp8_weight_caching: + # Initialize flag that controls FP8 weight updates FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) + # Check callables for c in callables: if isinstance(c, torch.nn.Module): assert ( @@ -110,9 +142,14 @@ def _make_graphed_callables( + ":func:`~make_graphed_callables`, only parameters may be trainable. " + "All buffers must have ``requires_grad=False``." ) - for args in sample_args: + + # Flatten callable arguments + per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs] + flatten_sample_args = [] + for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys): flatten_arg, _ = _tree_flatten(args) - flatten_sample_args.append(tuple(flatten_arg)) + flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys]) + flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg)) assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( "In the beta API, sample_args " + "for each callable must contain only Tensors. Other types are not allowed." @@ -120,6 +157,10 @@ def _make_graphed_callables( # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. + # Note: These per_callable_* variables are not actually + # per-callable, but per-forward-pass (see description of _order). + # The names are kept for consistency with + # torch.cuda.make_graphed_callables. per_callable_len_user_args = [len(args) for args in flatten_sample_args] if _order is None: per_callable_module_params = [ @@ -144,6 +185,7 @@ def _make_graphed_callables( fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))] + # For cases with multiple active RNG states, e.g. TP. if graph_safe_rng_available(): for _, state in get_all_rng_states().items(): @@ -158,11 +200,12 @@ def _make_graphed_callables( # from ending up in any captures. torch.cuda.synchronize() with torch.cuda.stream(torch.cuda.Stream()): - for c_i, func in enumerate(callables): - args = sample_args[c_i] - static_input_surface = per_callable_static_input_surfaces[c_i] + for func_idx, func in enumerate(callables): + args = sample_args[func_idx] + kwargs = sample_kwargs[func_idx] + static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): - outputs, _ = _tree_flatten(func(*args)) + outputs, _ = _tree_flatten(func(*args, **kwargs)) grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -170,7 +213,7 @@ def _make_graphed_callables( only_inputs=True, allow_unused=allow_unused_input, ) - del outputs, grad_inputs + del outputs, grad_inputs torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -194,9 +237,10 @@ def _make_graphed_callables( fwd_idx[m_chunk] * num_layers + l_no ) args = sample_args[per_callable_fwd_idx] + kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] with torch.cuda.graph(fwd_graph, pool=mempool): - outputs = func(*args) + outputs = func(*args, **kwargs) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec @@ -245,9 +289,9 @@ def _make_graphed_callables( per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] graph_id = 0 - for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): with torch.cuda.graph(fwd_graph, pool=mempool): - outputs = func(*args) + outputs = func(*args, **kwargs) graph_callables[graph_id] = func graph_id += 1 @@ -300,6 +344,7 @@ def make_graphed_autograd_function( fwd_graph, bwd_graph, module_params, + kwargs_keys, len_user_args, output_unflatten_spec, static_input_surface, @@ -312,14 +357,18 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, skip_fp8_weight_update, *inputs): - # At this stage, only the user args may (potentially) be new tensors. + + # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) + # Copy values from new tensors into static tensors for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) + + # Replay forward graph fwd_graph.replay() assert isinstance(static_outputs, tuple) return tuple(o.detach() for o in static_outputs) @@ -327,6 +376,8 @@ def forward(ctx, skip_fp8_weight_update, *inputs): @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, *grads): + + # Replay backward graph assert len(grads) == len(static_grad_outputs) for g, grad in zip(static_grad_outputs, grads): if g is not None: @@ -336,6 +387,7 @@ def backward(ctx, *grads): g.copy_(grad) bwd_graph.replay() + # Update FP8 scale factors if needed if ctx.is_first_module: FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -346,10 +398,8 @@ def backward(ctx, *grads): ) def functionalized(*user_args, **user_kwargs): - # Runs the autograd function with inputs == all - # inputs to the graph that might require grad - # (explicit user args + module parameters) - # Assumes module params didn't change since capture. + + # Decide whether to update FP8 weights skip_fp8_weight_update = None if fp8_weight_caching: assert "is_first_microbatch" in user_kwargs and isinstance( @@ -358,8 +408,22 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + # Check that required kwargs are provided + for key in kwargs_keys: + if key not in user_kwargs: + raise TypeError( + f"Graphed callable was initialized with kwarg {key} ," + "but it was not provided in graph replay" + ) + + # Runs the autograd function with inputs == all inputs to + # the graph that might require grad (explicit user args + + # module parameters) + # Assumes module params didn't change since capture. flatten_user_args, _ = _tree_flatten(user_args) - out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params)) + flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) + func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params + out = Graphed.apply(skip_fp8_weight_update, *func_args) return _tree_unflatten(out, output_unflatten_spec) return functionalized @@ -371,6 +435,7 @@ def functionalized(*user_args, **user_kwargs): fwd_graphs[i], bwd_graphs[i], per_callable_module_params[i], + per_callable_kwargs_keys[i], per_callable_len_user_args[i], per_callable_output_unflatten_spec[i], per_callable_static_input_surfaces[i], @@ -443,25 +508,42 @@ def restore_fp8_tensors(modules, fp8_tensors): def make_graphed_callables( - modules, - sample_args, - num_warmup_iters=3, - allow_unused_input=False, - fp8_enabled=False, - fp8_calibrating=False, - fp8_recipe=None, - fp8_weight_caching=False, - _order=None, -): + modules: SingleOrTuple[Callable], + sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, + fp8_enabled: bool = False, + fp8_calibrating: bool = False, + fp8_recipe: Optional[DelayedScaling] = None, + fp8_weight_caching: bool = False, + _order: Optional[List[int]] = None, +) -> Union[Callable, Tuple[Callable, ...]]: """ - A version of PyTorch's `make_graphed_callables` utility function with support for - TransformerEngine modules and FP8. Please see the original version in upstream PyTorch - `here `_ - for extensive documentation. The documentation for additional parameters which are - specific to FP8 are given below. - - FP8 specific parameters - ----------------------- + Make CUDA graph version of Transformer Engine modules + + A variation of PyTorch's `make_graphed_callables` utility function + with support for Transformer Engine modules and FP8. Please see + the + `original PyTorch implementation `_ + for more documentation. + + Graphing parameters + ------------------- + modules: (tuple of) callable + Callable or callables to graph. + sample_args: (tuple of) tuple of torch.Tensor + Positional arguments to callable(s). + num_warmup_iters: int, default = 3 + Number of warmup iterations. + allow_unused_input: bool, default = `False` + Whether to handle case where callable inputs + and outputs are disconnected in compute graph. + sample_kwargs: (tuple of) dict, optional + Keyword arguments to callable(s) + + FP8-related parameters + ---------------------- fp8_enabled: bool, default = `True` whether or not to enable fp8 fp8_calibrating: bool, default = `False` @@ -478,6 +560,7 @@ def make_graphed_callables( using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg must be set to `False` if calculating weight transposes' outside TE, e.g., in the optimizer step. + """ set_capture_start() @@ -532,6 +615,7 @@ def forward_func(*args, **kwargs): num_warmup_iters=num_warmup_iters, allow_unused_input=allow_unused_input, fp8_weight_caching=fp8_weight_caching, + sample_kwargs=sample_kwargs, _order=_order, ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cbcda20fe8..3613e1fa5e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -107,7 +107,7 @@ def initialize_ub( world_size = torch.distributed.get_world_size(mpi_group) local_rank = world_rank % tp_size local_size = tp_size - node_id = world_rank // tp_size + self_node_idx = world_rank // tp_size num_nodes = world_size // tp_size ub_callbacks = tex.UbufBootstrapCallbacks() else: @@ -127,13 +127,6 @@ def initialize_ub( world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - if world_rank == 0: - print( - f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n', - end="", - flush=True, - ) - # Construct an intra-node communicator based on global ranks that share the same hostname # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host # address on that interface instead of the hostname. This can help avoid issues when @@ -157,28 +150,41 @@ def initialize_ub( hostnames = [None for _ in range(world_size)] torch.distributed.all_gather_object(hostnames, hostname, world_group) - intra_node_ranks = [] - for i, host in enumerate(hostnames): - if host == hostname: - intra_node_ranks.append(i) - if len(intra_node_ranks) == world_size: + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + num_nodes = len(unique_hosts) + + if num_nodes > 1: + ranks_per_node_list = [[] for _ in range(num_nodes)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0, "Internal TE error!" + + intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_node_list, backend=bootstrap_backend + ) + local_rank = torch.distributed.get_rank(intra_node_group) + local_size = torch.distributed.get_world_size(intra_node_group) + intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group) + + else: + self_node_idx = 0 intra_node_group = world_group local_rank = world_rank local_size = world_size intra_node_ranks = list(range(world_size)) - else: - intra_node_group = torch.distributed.new_group( - backend=bootstrap_backend, ranks=intra_node_ranks - ) - local_rank = torch.distributed.get_rank(intra_node_group) - local_size = torch.distributed.get_world_size(intra_node_group) - node_id = world_rank // local_size - num_nodes = world_size // local_size + if world_rank == 0: + print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [NVTE] Number of physical nodes: {num_nodes}\n" - + f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n", + f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n", end="", flush=True, ) @@ -293,7 +299,7 @@ def add_ub( world_size, # World size local_rank, # Rank within the node local_size, # Number of ranks/GPUs per node - node_id, # Node ID + self_node_idx, # Node ID num_nodes, # Number of nodes tp_size, # Tensor-parallel group size (may be different than local_size) num_sm, # Number of communication SMs @@ -313,7 +319,7 @@ def add_ub( world_size, # World size local_rank, # Rank within the node local_size, # Number of ranks/GPUs per node - node_id, # Node ID + self_node_idx, # Node ID num_nodes, # Number of nodes tp_size, # Tensor-parallel group size (may be different than local_size) num_sm, # Number of communication SMs @@ -334,7 +340,9 @@ def add_ub( layers_reduce_scatter_overlap.remove(wgrad_name) layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) - methods["pipeline"].append(name) + methods["bulk"].remove(name) + new_method = ub_cfgs[name]["method"] + methods[new_method].append(name) for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: ub_cfg = get_default_config(name) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 352ce1ecbb..a91ff5c361 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,8 +3,6 @@ # See LICENSE for license information. """GroupedLinear API""" -import os -import logging from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -36,7 +34,7 @@ from ..cpp_extensions import ( cast_to_fp8, fp8_cast_transpose_bgrad_fused, - fp8_cast_transpose_fused, + fp8_multi_cast_transpose_fused, fp8_grouped_gemm, grouped_gemm, ) @@ -45,17 +43,6 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) - __all__ = ["GroupedLinear"] """ @@ -95,13 +82,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, + weights_fp8: List[Union[Float8Tensor, None]], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], ) -> torch.Tensor: - logger = logging.getLogger("GroupedLinear") num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] - weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms] - biases = weights_and_biases[2 * num_gemms :] + biases = weights_and_biases[num_gemms:] # Make sure input dimensions are compatible in_features = weights[0].shape[-1] @@ -127,15 +113,15 @@ def forward( and not sequence_parallel ): # FP8 input for forward, FP8 input transpose for backward wgrad - for i in range(num_gemms): - mat, mat_t = fp8_cast_transpose_fused( - inputmats_no_fp8[i], - fp8_meta["scaling_fwd"], - _GEMM_INPUT + i, - fp8_dtype_forward, - ) - inputmats.append(mat) - inputmats_t.append(mat_t) + indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms)) + inputmats, inputmats_t = fp8_multi_cast_transpose_fused( + inputmats_no_fp8, + fp8_meta["scaling_fwd"], + indices, # scale_indices + indices, # amax_indices + indices, # scale_inv_indices + fp8_dtype_forward, + ) else: # FP8 input for forward inputmats = [ @@ -151,8 +137,6 @@ def forward( inputmats = inputmats_no_fp8 if fp8: - logger.debug("Running forward in FP8") - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases @@ -184,8 +168,6 @@ def forward( use_split_accumulator=_2X_ACC_FPROP, ) else: - logger.debug("Running forward in %s", activation_dtype) - # Cast for native AMP weights = [cast_if_needed(w, activation_dtype) for w in weights] biases = ( @@ -237,9 +219,6 @@ def forward( saved_inputmats = inputmats_no_fp8 if cpu_offloading: - if fuse_wgrad_accumulation: - for w in weights: - w.main_grad.weight_offloading = True if fp8: for w in weights_fp8: if w is not None: @@ -289,8 +268,6 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - logger = logging.getLogger("GroupedLinear") - with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( fwd_scale_inverses, @@ -303,7 +280,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grads = saved_tensors[4 * ctx.num_gemms :] if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: for i in ctx.num_gemms: - w = torch.nn.Parameter(weights[i], False) + w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w @@ -331,13 +308,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - for i in range(ctx.num_gemms): - grad_output_c[i], grad_output_t[i] = fp8_cast_transpose_fused( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - _GRAD_OUTPUT + i, - fp8_dtype_backward, - ) + indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms)) + grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( + grad_output_mats, + ctx.fp8_meta["scaling_bwd"], + indices, # scale_indices + indices, # amax_indices + indices, # scale_inv_indices + fp8_dtype_backward, + ) else: for i in range(ctx.num_gemms): grad_output_c[i] = cast_to_fp8( @@ -356,9 +335,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - logger.debug("Running backward in FP8") dgrad = torch.empty( - (sum(ctx.m_splits), weights_fp8[i].size(1)), + (sum(ctx.m_splits), weights_fp8[0].size(1)), dtype=ctx.activation_dtype, device=grad_output.device, ) @@ -379,8 +357,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=_2X_ACC_DGRAD, ) else: - logger.debug("Running backward in %s", ctx.activation_dtype) - dgrad = torch.empty( (sum(ctx.m_splits), weights[0].size(1)), dtype=ctx.activation_dtype, @@ -513,8 +489,8 @@ def handle_custom_ddp_from_mcore(w, wgrad): None, # activation_dtype None, # parallel_mode None, # is_grad_enabled + None, # weights_fp8 *wgrad_list, - *([None] * ctx.num_gemms), # weights_fp8 *grad_biases, ) @@ -825,8 +801,8 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), + weight_tensors_fp8, *weight_tensors, - *weight_tensors_fp8, *bias_tensors, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e008bda2cf..10560cdad6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -5,7 +5,6 @@ """LayerNormLinear API""" import os import warnings -import logging from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -48,17 +47,6 @@ from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) - __all__ = ["LayerNormLinear"] @@ -104,7 +92,6 @@ def forward( ub_name: str, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: - logger = logging.getLogger("LayerNormLinear") # Make sure input dimensions are compatible in_features = ln_weight.numel() assert inp.shape[-1] == in_features, "GEMM not possible" @@ -203,8 +190,6 @@ def forward( ln_out = ln_out_total if fp8: - logger.debug("Running forward in FP8") - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -259,8 +244,6 @@ def forward( dtype=activation_dtype, ) else: - logger.debug("Running forward in %s", activation_dtype) - # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias @@ -379,7 +362,6 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - logger = logging.getLogger("LayerNormLinear") if isinstance(grad_outputs[0], Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ 0 @@ -411,7 +393,7 @@ def backward( ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight.requires_grad) + weight = torch.nn.Parameter(weight, weight.requires_grad) weight.main_grad = main_grad if ctx.ub_overlap_rs_dgrad: @@ -500,8 +482,6 @@ def backward( ub_obj = None if ctx.fp8: - logger.debug("Running backward in FP8") - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) out_index, meta_tensor, out_te_type, out_type = ( @@ -544,8 +524,6 @@ def backward( ) clear_tensor_data(grad_output_c) else: - logger.debug("Running backward in %s", ctx.activation_dtype) - # DGRAD: Evaluated unconditionally to feed into Linear backward _, _, _ = tex.gemm( weight, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2d364271aa..dc9bef645f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -13,6 +13,7 @@ from .base import ( get_workspace, + _ub_communicators, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -569,8 +570,8 @@ def backward( ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - fc1_weight = Parameter(fc1_weight.requires_grad) - fc2_weight = Parameter(fc2_weight.requires_grad) + fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad) + fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad) fc1_weight.main_grad = fc1_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad @@ -1297,7 +1298,7 @@ def __init__( self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == "gelu" - and not get_ub("fc1_fprop").is_atomic_gemm() + and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) if tp_group is None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a95fa1c33a..68d333262d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,8 +3,6 @@ # See LICENSE for license information. """Linear API""" -import os -import logging from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -51,17 +49,6 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -logging.basicConfig( - format="[%(levelname)-8s | %(name)-19s]: %(message)s", - level=log_levels[log_level if log_level in [0, 1, 2] else 2], -) - __all__ = ["Linear"] @@ -97,7 +84,6 @@ def forward( is_first_module_in_mha: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: - logger = logging.getLogger("Linear") is_input_fp8 = isinstance(inp, Float8Tensor) if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] @@ -158,8 +144,6 @@ def forward( else: inputmat_total = inputmat if fp8: - logger.debug("Running forward in FP8") - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -248,8 +232,6 @@ def forward( dtype=activation_dtype, ) else: - logger.debug("Running forward in %s", activation_dtype) - # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias @@ -373,7 +355,6 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - logger = logging.getLogger("Linear") if isinstance(grad_output, Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[ tex.FP8BwdTensors.GRAD_OUTPUT1 @@ -401,7 +382,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight.requires_grad) + weight = torch.nn.Parameter(weight, weight.requires_grad) weight.main_grad = main_grad tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -450,8 +431,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - logger.debug("Running backward in FP8") - if ctx.is_input_fp8: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, @@ -494,8 +473,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, ) else: - logger.debug("Running backward in %s", ctx.activation_dtype) - dgrad, _, _ = gemm( weight, grad_output, diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index ec3d4fd315..f437f877b4 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -9,11 +9,13 @@ """ from transformer_engine.pytorch.ops.basic import ( + AddInPlace, AllGather, AllReduce, BasicLinear, Bias, Identity, + MakeExtraOutput, ReduceScatter, Reshape, ) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 3621910c8b..1003cc0337 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,10 +4,12 @@ """Single tensor operations supported by the operation fuser.""" +from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias from .identity import Identity +from .make_extra_output import MakeExtraOutput from .reduce_scatter import ReduceScatter from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_in_place.py new file mode 100644 index 0000000000..041888f5d7 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/add_in_place.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for in-place add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) + + +class AddInPlace(BasicOperation): + """Add in-place + + This operation requires an extra tensor input to the operation + fuser. The main input is added in-place to the extra input, and a + view of the extra input is output. + + This operation is considered an advanced feature and most users + are discouraged from using it. In-place operations break some + autograd assumptions and they can result in subtle, esoteric bugs. + + Compare to `MakeExtraOutput`, which does a similar operation in + the backward pass. + + """ + + # Operation expects buffer for output tensor + num_extra_inputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + output = basic_op_extra_inputs[0][0].detach() + output += input_ + return output, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + return grad_output, [], [(grad_output,)] diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 49923e7af8..826807d1c0 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,7 +12,11 @@ import torch -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, gemm +from transformer_engine.pytorch.cpp_extensions import ( + FP8TensorMeta, + fp8_gemm, + gemm, +) from transformer_engine.pytorch.distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -32,6 +36,7 @@ canonicalize_device, canonicalize_dtype, convert_tensor, + devices_match, is_float8_tensor, reshape, ) @@ -308,6 +313,8 @@ def _functional_forward( bias: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + out: Optional[torch.Tensor] = None, + accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, @@ -330,6 +337,10 @@ def _functional_forward( Tensor device dtype: torch.dtype, default = default dtype Tensor datatype + out: torch.Tensor, optional + Output tensor + accumulate_into_out: bool, default = `False` + Add result to output tensor instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group @@ -365,19 +376,25 @@ def _functional_forward( # Check device if device is None: - device = weight.device + device = weight.device if out is None else out.device device = canonicalize_device(device) if device.type != "cuda": raise ValueError(f"Only CUDA devices are supported (got {device})") + if out is not None and not devices_match(out.device, device): + raise ValueError( + f"Output tensor has invalid device (expected {device}, got {out.device})" + ) # Check datatype if dtype is None: - dtype = weight.dtype + dtype = weight.dtype if out is None else out.dtype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + if out is not None and out.dtype != dtype: + raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})") - # Check tensor dims + # Check input tensor dims input_dims = tuple(input.size()) weight_dims = tuple(weight.size()) if len(weight_dims) != 2: @@ -389,6 +406,32 @@ def _functional_forward( "are not compatible" ) + # Check output tensor dims + output_dims: list[int] + if out is None: + output_dims = list(input_dims) + output_dims[0] = -1 + output_dims[-1] = weight_dims[0] + else: + output_dims = list(out.size()) + if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check if accumulating into output tensor + if accumulate_into_out: + if out is None: + raise ValueError( + "Attempted to accumulate into output tensor without providing output tensor" + ) + if tensor_parallel_mode == "row": + raise ValueError( + "Accumulating into output tensor is not supported with row tensor parallelism" + ) + # Check if FP8 is enabled if with_fp8_compute: if input_fp8_meta is None and not is_float8_tensor(input): @@ -399,9 +442,18 @@ def _functional_forward( input_fp8_meta = None weight_fp8_meta = None output_fp8_meta = None - with_fp8_output = ( - with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None - ) + with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" + if out is None: + with_fp8_output = with_fp8_output and output_fp8_meta is not None + else: + if is_float8_tensor(out): + if not with_fp8_output: + raise ValueError( + "Output tensor is a Float8Tensor, but FP8 output is not supported" + ) + out._reset_caches() + else: + with_fp8_output = False # Check input tensor x_local = reshape( @@ -476,7 +528,9 @@ def _functional_forward( # Construct output tensor y = None - if with_fp8_output: + if out is not None: + y = reshape(out, (-1, output_dims[-1])) + elif with_fp8_output: fp8_dtype = get_fp8_te_dtype( output_fp8_meta["recipe"], fprop_tensor=True, @@ -506,19 +560,31 @@ def _functional_forward( x_async = None if with_fp8_compute: kwargs = dict( + accumulate=accumulate_into_out, out=y, bias=b, use_bias=(b is not None), ) if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=y._fp8_meta_forward, - ) + if y._fp8_meta is None: + # Hackily create FP8TensorMeta if needed + fp8_meta = FP8TensorMeta() + fp8_meta.scale = y._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) + fp8_meta.scale_inv = y._scale_inv + fp8_meta_index = 0 + else: + # Get FP8TensorMeta from Float8Tensor + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=y._fp8_meta_forward, + ) + fp8_meta = y._fp8_meta[fp8_meta_key] + fp8_meta_index = y._fp8_meta_index kwargs.update( dict( out=y._data, - out_index=y._fp8_meta_index, - fp8_meta_tensor=y._fp8_meta[fp8_meta_key], + out_index=fp8_meta_index, + fp8_meta_tensor=fp8_meta, D_dtype=y._fp8_dtype, ) ) @@ -541,6 +607,7 @@ def _functional_forward( x, y.dtype, get_workspace(), + accumulate=accumulate_into_out, out=y, bias=b, use_bias=(b is not None), @@ -553,13 +620,11 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Reshape output tensor - output_dims = list(input_dims) - output_dims[0] = -1 - output_dims[-1] = weight_dims[0] - output = reshape(y, output_dims) + # Reshape output tensor if needed + if out is None: + out = reshape(y, output_dims) - return output, x_local, w + return out, x_local, w @staticmethod def _functional_backward( @@ -573,6 +638,10 @@ def _functional_backward( weight_requires_grad: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + grad_weight: Optional[torch.Tensor] = None, + accumulate_into_grad_weight: bool = False, + grad_input: Optional[torch.Tensor] = None, + accumulate_into_grad_input: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, @@ -581,8 +650,6 @@ def _functional_backward( weight_fp8_meta: Optional[dict[str, Any]] = None, grad_output_fp8_meta: Optional[dict[str, Any]] = None, grad_input_fp8_meta: Optional[dict[str, Any]] = None, - accumulate_into_grad_weight: bool = False, - grad_weight: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Functional API for backward pass @@ -608,6 +675,14 @@ def _functional_backward( Tensor device dtype: torch.dtype, default = default dtype Tensor datatype + grad_weight: torch.Tensor, optional + Loss gradient w.r.t. weight tensor + accumulate_into_grad_weight: bool, default = `False` + Add result to weight grad instead of overwriting + grad_input: torch.Tensor, optional + Loss gradient w.r.t. input tensor + accumulate_into_grad_input: bool, default = `False` + Add result to input grad instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group @@ -632,10 +707,6 @@ def _functional_backward( grad_output_fp8_meta: dict, optional FP8 metadata for casting loss gradient w.r.t. input tensor to FP8 - accumulate_into_grad_weight: bool, default = `False` - Accumulate into weight grad instead of overwriting - grad_weight: torch.Tensor, optional - Loss gradient w.r.t. weight tensor Returns ------- @@ -678,6 +749,34 @@ def _functional_backward( f"and weight tensor (shape={weight_dims}) " "are not compatible" ) + if grad_input is not None and tuple(grad_input.size()) != input_dims: + raise ValueError( + f"Grad input tensor (shape={tuple(grad_input.size())}) " + f"does not match expected shape ({input_dims})" + ) + + # Check grad input tensor + if not input_requires_grad: + grad_input = None + if grad_input is not None and not devices_match(grad_input.device, device): + raise ValueError( + f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})" + ) + if grad_input is not None and grad_input.dtype != dtype: + raise ValueError( + f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})" + ) + if accumulate_into_grad_input: + if grad_input is None: + raise ValueError( + "Attempted to accumulate into grad input tensor " + "without providing grad input tensor" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Accumulating into grad input tensor " + "is not supported with column tensor parallelism" + ) # Check if FP8 is enabled if with_fp8_compute: @@ -689,11 +788,19 @@ def _functional_backward( grad_output_fp8_meta = None grad_input_fp8_meta = None with_fp8_grad_input = ( - with_fp8_compute - and input_requires_grad - and tensor_parallel_mode != "column" - and grad_input_fp8_meta is not None + with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column" ) + if grad_input is None: + with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None + else: + if is_float8_tensor(grad_input): + if not with_fp8_grad_input: + raise ValueError( + "Grad input tensor is a Float8Tensor, but FP8 output is not supported" + ) + grad_input._reset_caches() + else: + with_fp8_grad_input = False # Check grad output tensor dy_async = None @@ -806,7 +913,9 @@ def _functional_backward( w = w.from_float8() # Construct grad input tensor - if with_fp8_grad_input: + if grad_input is not None: + dx = reshape(grad_input, (-1, input_dims[-1])) + elif with_fp8_grad_input: fp8_dtype = get_fp8_te_dtype( grad_input_fp8_meta["recipe"], fprop_tensor=False, @@ -835,16 +944,32 @@ def _functional_backward( _wait_async(dy_async) dy_async = None if with_fp8_compute: - kwargs = dict(out=dx) + kwargs = dict( + accumulate=accumulate_into_grad_input, + out=dx, + ) if with_fp8_grad_input: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dx._fp8_meta_forward, - ) + if dx._fp8_meta is None: + # Hackily create FP8TensorMeta if needed + fp8_meta = FP8TensorMeta() + fp8_meta.scale = dx._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty( + 1, 1, dtype=torch.float32, device=device + ) + fp8_meta.scale_inv = dx._scale_inv + fp8_meta_index = 0 + else: + # Get FP8TensorMeta from Float8Tensor + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dx._fp8_meta_forward, + ) + fp8_meta = dx._fp8_meta[fp8_meta_key] + fp8_meta_index = dx._fp8_meta_index kwargs.update( dict( out=dx._data, - out_index=dx._fp8_meta_index, - fp8_meta_tensor=dx._fp8_meta[fp8_meta_key], + out_index=fp8_meta_index, + fp8_meta_tensor=fp8_meta, D_dtype=dx._fp8_dtype, ) ) @@ -867,6 +992,7 @@ def _functional_backward( dy, dx.dtype, get_workspace(), + accumulate=accumulate_into_grad_input, layout="NN", out=dx, ) @@ -936,8 +1062,7 @@ def _functional_backward( _wait_async(dy_async) _wait_async(x_async) _wait_async(dx_async) - grad_input = None - if dx is not None: + if dx is not None and grad_input is None: grad_input = reshape(dx, input_dims) return grad_input, grad_weight @@ -1027,6 +1152,8 @@ def op_backward( weight_requires_grad=ctx.weight_requires_grad, device=self.device, dtype=self.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, @@ -1034,8 +1161,6 @@ def op_backward( weight_fp8_meta=ctx.weight_fp8_meta, grad_output_fp8_meta=ctx.grad_output_fp8_meta, grad_input_fp8_meta=ctx.grad_input_fp8_meta, - accumulate_into_grad_weight=accumulate_into_main_grad, - grad_weight=grad_weight, ) # Clear input tensor if possible diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py new file mode 100644 index 0000000000..db1651c184 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Make extra tensor output in operation fuser.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) + + +class MakeExtraOutput(BasicOperation): + """Make extra output in operation fuser + + If this operation is included in the operation fuser, then the + operation fuser will return the intermediate tensor as an extra + tensor output. In the backward pass, the gradient is directly + accumulated into the gradient w.r.t. the extra output. + + This operation is considered an advanced feature and most users + are discouraged from using it. In-place operations break some + autograd assumptions and they can result in subtle, esoteric bugs. + + Compare to `AddInPlace`, which does a similar operation in the + backward pass. + + """ + + # Operation expects buffer for output tensor + num_extra_outputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + return input_, [(input_,)] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + grad_input = basic_op_grad_extra_outputs[0][0] + grad_input += grad_output + return grad_input, [], [()] diff --git a/transformer_engine/pytorch/ops/fused_forward/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py similarity index 52% rename from transformer_engine/pytorch/ops/fused_forward/__init__.py rename to transformer_engine/pytorch/ops/fused/__init__.py index ed523a067a..bd832254d8 100644 --- a/transformer_engine/pytorch/ops/fused_forward/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -4,7 +4,15 @@ """Compound tensor operation supported by the operation fuser.""" -from .linear_bias_activation import ( +from .backward_linear_add import ( + BackwardLinearAdd, + fuse_backward_linear_add, +) +from .forward_linear_bias_activation import ( ForwardLinearBiasActivation, fuse_forward_linear_bias_activation, ) +from .forward_linear_bias_add import ( + ForwardLinearBiasAdd, + fuse_forward_linear_bias_add, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py new file mode 100644 index 0000000000..138eca3d96 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused backward dgrad GEMM + add.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput +from transformer_engine.pytorch.ops.op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...utils import clear_tensor_data + + +class BackwardLinearAdd(FusedOperation): + """Fused backward dgrad GEMM + add + + Column tensor parallelism is not supported since that requires + communication immediately after the dgrad GEMM. + + """ + + def __init__( + self, + *, + linear: BasicLinear, + backward_add: MakeExtraOutput, + ) -> None: + super().__init__((linear, backward_add)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + linear_op = self.basic_ops[0] + linear_op_ctx = basic_op_ctxs[0] + + # Saved tensors from forward pass + (x_local,) = linear_op_ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if not hasattr(linear_op.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = linear_op.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Linear backward pass + grad_input = basic_op_grad_extra_outputs[1][0] + grad_input, grad_weight = BasicLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=linear_op.weight, + input_dims=linear_op_ctx.input_dims, + weight_dims=linear_op.weight.size(), + input_requires_grad=linear_op_ctx.input_requires_grad, + weight_requires_grad=linear_op_ctx.weight_requires_grad, + device=linear_op.device, + dtype=linear_op.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + grad_input=grad_input, + accumulate_into_grad_input=True, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_fp8_compute=linear_op_ctx.with_fp8_compute, + weight_fp8_meta=linear_op_ctx.weight_fp8_meta, + grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, + grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + ) + if accumulate_into_main_grad: + grad_weight = None + + # Clear input tensor if possible + if linear_op_ctx.has_prev_op: + clear_tensor_data(x_local) + + return grad_input, [(grad_weight,), ()], [(), ()] + + +def fuse_backward_linear_add( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward dgrad GEMM + add + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. + + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "column": + # Row tensor-parallelism requires communication after the + # GEMM + continue + + # Check if second op is "make extra output" + op, _ = ops[0] + if not isinstance(op, MakeExtraOutput): + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardLinearAdd( + linear=window[0][0], + backward_add=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py similarity index 93% rename from transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py rename to transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 1504dc4a53..5fd52405e4 100644 --- a/transformer_engine/pytorch/ops/fused_forward/linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -2,9 +2,10 @@ # # See LICENSE for license information. -"""Fused operation for GEMM, bias, activation in the forward pass.""" +"""Fused operation for forward GEMM + bias + activation.""" from __future__ import annotations +from collections.abc import Iterable from typing import Any, Optional import torch @@ -20,7 +21,7 @@ class ForwardLinearBiasActivation(FusedOperation): - """Fused GEMM, bias, activation in the forward pass + """Fused forward GEMM + bias + activation Bias and activation are both optional. Row tensor parallelism is not supported since that requires communication immediately after @@ -60,10 +61,12 @@ def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]], basic_op_kwargs: list[dict[str, Any]], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations idx = self._op_idxs["linear"] @@ -128,13 +131,13 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None - return output + return output, [() for _ in range(len(self.basic_ops))] def fuse_forward_linear_bias_activation( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse GEMM, bias, activation in the forward pass + """Fuse forward GEMM + bias + activation Parameters ---------- diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py new file mode 100644 index 0000000000..6ddee2849a --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -0,0 +1,196 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for forward GEMM + bias + add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) + + +class ForwardLinearBiasAdd(FusedOperation): + """Fused forward GEMM + bias + add + + Bias is optional. Row tensor parallelism is not supported since + that requires communication immediately after the GEMM. + + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + add: AddInPlace, + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = dict( + linear=0, + bias=None, + add=None, + ) + ops = [linear] + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + op_idxs["add"] = len(ops) + ops.append(add) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + if self._op_idxs["bias"] is None: + bias_op = None + bias = None + else: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + if basic_op_kwargs[idx]: + raise ValueError("Bias operation forward does not expect keyword arguments") + + # FP8 metadata + with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = linear_op.get_fp8_meta("input") + weight_fp8_meta = linear_op.get_fp8_meta("param") + grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + prev_op = basic_op_prev_ops[0] + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Linear forward + output = basic_op_extra_inputs[self._op_idxs["add"]][0] + output, x_local, _ = BasicLinear._functional_forward( + input=input_, + weight=linear_op.weight, + bias=bias, + device=linear_op.device, + dtype=linear_op.dtype, + out=output, + accumulate_into_out=True, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, + weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ) + + # Save state for backward pass + linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.with_fp8_compute = with_fp8_compute + linear_op_ctx.weight_fp8_meta = weight_fp8_meta + linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta + linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_.requires_grad + linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad + linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_forward_linear_bias_add( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fuse forward GEMM + bias + add + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. + + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after the + # GEMM + continue + linear = op + op, _ = ops[0] + + # Check if next op is bias + bias = None + if isinstance(op, Bias): + bias = op + window.extend(ops[:1]) + ops = ops[1:] + if len(ops) == 0: + continue + op, _ = ops[0] + + # Check if next op is add in-place + if not isinstance(op, AddInPlace): + continue + add = op + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = ForwardLinearBiasAdd( + linear=linear, + bias=bias, + add=add, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 06ea608ed8..a7c99c592d 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -16,11 +16,18 @@ FusibleOperation, OperationContext, ) -from transformer_engine.pytorch.ops.fused_forward import ( +from transformer_engine.pytorch.ops.fused import ( + fuse_backward_linear_add, fuse_forward_linear_bias_activation, + fuse_forward_linear_bias_add, ) +def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: + """Split tuple at index""" + return t[:idx], t[idx:] + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -38,8 +45,10 @@ def forward( backward_ops: list[tuple[FusibleOperation, list[int]]], basic_ops: list[BasicOperation], basic_op_kwargs: list[dict[str, Any]], - *params: torch.nn.Parameter, - ) -> torch.Tensor: + num_params: int, + num_extra_inputs: int, + *params_and_extra_inputs: torch.nn.Parameter, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass Parameters @@ -60,39 +69,82 @@ def forward( Basic operations basic_op_kwargs: list of dict Keyword arguments to BasicOperation - *params: torch.nn.Parameter - Parameters in operation pipeline + num_params: int + Number of parameter tensors to include in autograd graph. + *params_and_extra_inputs: torch.Tensor + Other tensor inputs to include in autograd graph. Consists + of parameter tensors, followed by extra operation inputs. + + Returns + ------- + Output tensor(s). If none of the operations have any extra + tensor outputs, then the pipeline's output tensor is returned. + Otherwise, a tuple with the pipeline's output tensor and extra + tensor outputs is returned. """ # Operation autograd contexts basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] + # Unflatten list of parameters and extra tensor inputs + if len(params_and_extra_inputs) != num_params + num_extra_inputs: + raise ValueError( + f"Expected {num_params + num_extra_inputs} extra tensor arguments " + f"({num_params} parameters, {num_extra_inputs} extra inputs), " + f"but got {len(params_and_extra_inputs)}" + ) + _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params) + basic_op_extra_inputs = [] + for op in basic_ops: + xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) + basic_op_extra_inputs.append(xs) + # Apply forward ops x = input_ requires_grad = x.requires_grad + extra_outputs = [None for _ in range(len(basic_ops))] for op, basic_op_idxs in forward_ops: # Forward op + extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] next_ops = [ basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs ] - x = op.fuser_forward( + x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], x, - prev_ops, - next_ops, - [basic_op_kwargs[idx] for idx in basic_op_idxs], + basic_op_extra_inputs=extra_inputs, + basic_op_prev_ops=prev_ops, + basic_op_next_ops=next_ops, + basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) + for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): + extra_outputs[idx] = ys # Check if backward op is required if not requires_grad: requires_grad = any(param.requires_grad for param in op.parameters()) + if not requires_grad: + requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx]._requires_grad = requires_grad x.requires_grad_(requires_grad=requires_grad) + # Flatten list of extra outputs + extra_outputs_flat = [] + for idx, ys in enumerate(extra_outputs): + ys = list(ys) + num_extra_outputs = basic_ops[idx].num_extra_outputs + if len(ys) != num_extra_outputs: + raise RuntimeError( + f"Expected op {idx} to generate " + "{num_extra_outputs} extra inputs, " + f"but got {len(ys)}" + ) + extra_outputs_flat.extend(ys) + # Flatten list of saved tensors to_save = [] for ctx in basic_op_ctxs: @@ -108,8 +160,13 @@ def forward( func_ctx.backward_ops = backward_ops func_ctx.basic_ops = basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs + func_ctx.num_params = num_params + func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + if extra_outputs_flat: + return x, *extra_outputs_flat return x @staticmethod @@ -117,6 +174,7 @@ def forward( def backward( func_ctx: Any, grad_output: torch.Tensor, + *grad_extra_outputs: torch.Tensor, ) -> tuple[Optional[torch.Tensor], ...]: """Backward pass""" @@ -126,15 +184,25 @@ def backward( basic_op_ctxs = func_ctx.basic_op_ctxs # Unflatten list of saved tensors - saved_tensors = func_ctx.saved_tensors for ctx in basic_op_ctxs: - ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)] + ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)] ctx._saved_tensors_range = None - del saved_tensors + + # Unflatten list of extra tensor output grads + if len(grad_extra_outputs) != func_ctx.num_extra_outputs: + raise ValueError( + f"Expected grads for {func_ctx.num_extra_outputs} extra tensor outputs, " + f"but got {len(grad_extra_outputs)}" + ) + basic_op_grad_extra_outputs = [] + for op in basic_ops: + dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs) + basic_op_grad_extra_outputs.append(dys) # Apply backward ops dx = grad_output grad_params = [None for _ in range(len(basic_ops))] + grad_extra_inputs = [None for _ in range(len(basic_ops))] for op, basic_op_idxs in backward_ops: # Stop if no more gradients are required @@ -143,13 +211,17 @@ def backward( break # Backward op - dx, fused_op_dparams = op.fuser_backward( + grad_extra_outputs = [basic_op_grad_extra_outputs[idx] for idx in basic_op_idxs] + dx, fused_op_grad_params, fused_op_grad_extra_inputs = op.fuser_backward( [basic_op_ctxs[idx] for idx in basic_op_idxs], dx, + basic_op_grad_extra_outputs=grad_extra_outputs, ) - for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams): - grad_params[idx] = basic_op_dparams + for idx, dparams in zip(basic_op_idxs, fused_op_grad_params): + grad_params[idx] = dparams basic_op_ctxs[idx].saved_tensors = None + for idx, dxs in zip(basic_op_idxs, fused_op_grad_extra_inputs): + grad_extra_inputs[idx] = dxs # Flatten list of parameter gradients grad_params_flat = [] @@ -166,6 +238,22 @@ def backward( ) grad_params_flat.extend(dparams) + # Flatten list of parameter gradients + grad_extra_inputs_flat = [] + for idx, dxs in enumerate(grad_extra_inputs): + num_extra_inputs = basic_ops[idx].num_extra_inputs + if dxs is None: + dxs = [None for _ in range(num_extra_inputs)] + else: + dxs = list(dxs) + if len(dxs) != num_extra_inputs: + raise RuntimeError( + f"Expected op {idx} to generate grads " + f"for {num_extra_inputs} extra inputs, " + f"but got {len(dxs)}" + ) + grad_extra_inputs_flat.extend(dxs) + # Update FP8 scaling factors if func_ctx.is_first_module and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -176,7 +264,10 @@ def backward( None, # backward_ops None, # basic_ops None, # basic_op_kwargs - *grad_params_flat, # params + None, # num_params + None, # num_extra_inputs + *grad_params_flat, + *grad_extra_inputs_flat, ) @@ -208,6 +299,9 @@ def __init__( self._num_basic_ops: int = len(basic_ops) self._basic_ops: list[BasicOperation] = basic_ops + # Number of extra tensor inputs + self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops) + # Ops for forward and backward pass self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] @@ -224,6 +318,7 @@ def _fuse_forward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward pass""" + ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) return ops @@ -233,6 +328,7 @@ def _fuse_backward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in backward pass""" + ops = fuse_backward_linear_add(ops) return ops def fuse_ops(self) -> None: @@ -243,8 +339,9 @@ def fuse_ops(self) -> None: def __call__( self, input: torch.Tensor, # pylint: disable=redefined-builtin + *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: # Initialization before forward pass for op in self._basic_ops: @@ -255,9 +352,7 @@ def __call__( basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] # Flatten list of parameters - params = [] - for op in self._basic_ops: - params.extend(op.parameters()) + params = [param for op in self._basic_ops for param in op.parameters()] # Fuser forward pass return _OperationFuserAutogradFunction.apply( @@ -266,5 +361,8 @@ def __call__( self._backward_ops, self._basic_ops, basic_op_kwargs, + len(params), + self._num_extra_inputs, *params, + *extra_inputs, ) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 3d90d07b84..47c6567056 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -67,10 +67,12 @@ def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]], basic_op_kwargs: list[dict[str, Any]], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: """Forward pass This op is either a basic op or the fusion of basic ops, so @@ -82,24 +84,27 @@ def fuser_forward( Parameters ---------- basic_op_ctxs: list of OperationContext - Contexts for corresponding basic operations + Contexts for basic operations input_: torch.Tensor Input tensor + basic_op_extra_inputs: list of torch.Tensor + Extra tensor inputs to basic operations basic_op_prev_ops: list of BasicOperation - Basic operations that preceed each of the corresponding - basic operations (or `None` if corresponding basic op is - first) + Basic operations that preceed this operation's basic + operations basic_op_next_ops: list of BasicOperation - Basic operations that follow each of the corresponding - basic operations (or `None` if corresponding basic op is - last) + Basic operations that follow this operation's basic + operations basic_op_kwargs: list of dict - Keyword arguments to forward functions of corresponding - basic operations + Keyword arguments to forward functions of basic + operations. Returns ------- - torch.Tensor: Output tensor. + torch.Tensor: + Output tensor. + Iterable of torch.Tensor: + Extra tensor outputs from basic operations. """ raise NotImplementedError( @@ -110,7 +115,13 @@ def fuser_backward( self, basic_op_ctxs: list[OperationContext], grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: """Backward pass This op is either a basic op or the fusion of basic ops, so @@ -122,24 +133,21 @@ def fuser_backward( Parameters ---------- basic_op_ctxs: list of OperationContext - Contexts for corresponding basic operations. + Contexts for basic operations grad_output: torch.Tensor - Loss gradient w.r.t. operation output. - basic_op_prev_ops: list of BasicOperation - Basic operations that preceed each of the corresponding - basic operations (or `None` if corresponding basic op is - first) - basic_op_next_ops: list of BasicOperation - Basic operations that follow each of the corresponding - basic operations (or `None` if corresponding basic op is - last) + Loss gradient w.r.t. operation output + basic_op_grad_extra_outputs: list of tuple of torch.Tensor + Loss gradients w.r.t. extra tensor outputs from basic + operations. Returns ------- torch.Tensor: Loss gradient w.r.t. operation input Iterable of iterable of torch.Tensor: - Loss gradients w.r.t. parameters for corresponding basic + Loss gradients w.r.t. parameters for basic operations + Iterable of iterable of torch.Tensor: + Loss gradients w.r.t. extra tensor inputs to basic operations """ @@ -156,6 +164,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): """ + # Number of extra tensor inputs + num_extra_inputs: int = 0 + # Number of extra tensor outputs + num_extra_outputs: int = 0 + def __init__(self) -> None: super().__init__() @@ -297,6 +310,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, + *, prev_op: Optional[BasicOperation] = None, next_op: Optional[BasicOperation] = None, **kwargs: Any, @@ -309,6 +323,10 @@ def op_forward( Context to coordinate between forward and backward passes input_: torch.Tensor Input tensor + prev_op: BasicOperation, optional + Basic operation that preceeds this operation + next_op: BasicOperation, optional + Basic operation that follows this operation Returns ------- @@ -345,35 +363,63 @@ def fuser_forward( self, basic_op_ctxs: list[OperationContext], input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_prev_ops: list[Optional[BasicOperation]], basic_op_next_ops: list[Optional[BasicOperation]], basic_op_kwargs: list[dict[str, Any]], - ) -> torch.Tensor: - return self.op_forward( + ) -> tuple[torch.Tensor, list[tuple[()]]]: + if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It should override `fuser_forward` instead of `op_forward`." + ) + output = self.op_forward( basic_op_ctxs[0], input_, - basic_op_prev_ops[0], - basic_op_next_ops[0], + prev_op=basic_op_prev_ops[0], + next_op=basic_op_next_ops[0], **basic_op_kwargs[0], ) + return output, [()] def fuser_backward( self, basic_op_ctxs: list[OperationContext], grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]]]: + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[Iterable[Optional[torch.Tensor]]], + list[tuple[()]], + ]: + if self.num_extra_inputs > 0 or self.num_extra_outputs > 0: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It should override `fuser_backward` instead of `op_backward`." + ) grad_input, grad_params = self.op_backward(basic_op_ctxs[0], grad_output) - return grad_input, [grad_params] + return grad_input, [grad_params], [()] def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin + *extra_inputs: torch.Tensor, **kwargs: Any, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Apply operation""" from .fuser import OperationFuser - return OperationFuser([self], fuse_ops=False)(input, [kwargs]) + return OperationFuser([self], fuse_ops=False)( + input, + *extra_inputs, + basic_op_kwargs=[kwargs], + ) class FusedOperation(FusibleOperation): @@ -417,6 +463,7 @@ def pre_forward(self) -> None: def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin + *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor: """Apply operation""" @@ -424,4 +471,8 @@ def forward( basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] from .fuser import OperationFuser - return OperationFuser([self], fuse_ops=False)(input, basic_op_kwargs) + return OperationFuser([self], fuse_ops=False)( + input, + *extra_inputs, + basic_op_kwargs=basic_op_kwargs, + ) diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 57b4036bba..c5e25fe1f2 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -144,28 +144,44 @@ def _make_module_groups( modules: Iterable[torch.nn.Module], ) -> list[OperationFuser | torch.nn.Module]: """Make list of modules, with fusible operations grouped together""" - module_groups = [] - fusible_ops = [] - - def maybe_add_fuser(): - nonlocal fusible_ops - if fusible_ops: - module_groups.append(OperationFuser(fusible_ops, fuse_ops=True)) - fusible_ops = [] + # Group fusible operations together + groups = [] for module in modules: if isinstance(module, FusibleOperation): - fusible_ops.append(module) + if not groups or not isinstance(groups[-1], list): + groups.append([]) + groups[-1].append(module) else: - maybe_add_fuser() - module_groups.append(module) - maybe_add_fuser() - return module_groups + groups.append(module) + for idx, group in enumerate(groups): + if isinstance(group, list): + groups[idx] = OperationFuser(group, fuse_ops=True) + + # Check if operations expect extra input or output tensors + # Note: If any op has extra inputs or outputs, then the entire + # Sequential must be made up of TE ops. + if len(groups) > 1: + ops = [] + for group in groups: + if isinstance(group, OperationFuser): + ops.extend(group._basic_ops) + num_extra_inputs = sum(op.num_extra_inputs for op in ops) + num_extra_outputs = sum(op.num_extra_outputs for op in ops) + if num_extra_inputs > 0 or num_extra_outputs > 0: + raise RuntimeError( + f"`Sequential` expects {num_extra_inputs} extra inputs " + f"and {num_extra_outputs} extra outputs, " + "but it contains non-fusible operations" + ) + + return groups def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin - ) -> torch.Tensor: + *extra_inputs: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass""" # Create module groups if needed @@ -175,5 +191,5 @@ def forward( # Forward pass for each module group x = input for module_group in self._module_groups: - x = module_group(x) + x = module_group(x, *extra_inputs) return x diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 8cbe720a74..fc9bdc304a 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -8,6 +8,7 @@ multi_tensor_l2norm, multi_tensor_unscale_l2norm, multi_tensor_adam, + multi_tensor_adam_fp8, multi_tensor_adam_capturable, multi_tensor_adam_capturable_master, multi_tensor_sgd, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 91ce502390..322b93a1d8 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -5,9 +5,27 @@ """Fused Adam optimizer.""" import torch import transformer_engine_torch as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from .multi_tensor_apply import multi_tensor_applier +def get_fp8_meta(fp8_tensor): + """FP8 metadata getter.""" + if fp8_tensor._fp8_meta is None: + raise RuntimeError("FP8 meta data is not initialized.") + + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=fp8_tensor._fp8_meta_forward, + ) + + fp8_meta_index = fp8_tensor._fp8_meta_index + scale = fp8_tensor._fp8_meta[fp8_meta_key].scale[fp8_meta_index] + amax = fp8_tensor._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + scale_inv = fp8_tensor._scale_inv + return scale, amax, scale_inv + + class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. @@ -50,9 +68,11 @@ class FusedAdam(torch.optim.Optimizer): method is called. (default: True) capturable (bool, optional): whether to use the version of the optimizer that can be used with CUDA Graphs. (default: False) - master_weights (bool, optional): whether to maintain FP32 master weights - in the optimizer with FP16 mixed precision training, currently can - only be used with capturable set to True. (default: False) + master_weights (list of torch.Tensor, optional): master weights to use + for mixed precision training. If provided, the optimizer will update + the master weights and then cast the master weights to the model weights. + If not provided, the optimizer will update the model weights directly. + (default: None) .. _Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -72,15 +92,12 @@ def __init__( amsgrad=False, set_grad_none=True, capturable=False, - master_weights=False, + master_weights=None, ): if amsgrad: raise RuntimeError("FusedAdam does not support the AMSGrad variant.") - if master_weights and not capturable: - raise RuntimeError( - "Master weights is currently only supported with the capturable version." - ) + # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr defaults = dict( @@ -95,20 +112,10 @@ def __init__( self.set_grad_none = set_grad_none self.capturable = capturable - self.master_weights = master_weights - # Create full precision master weights - self.param_groups_master = [] - for _, pg in enumerate(self.param_groups): - param_list = pg["params"] - self.param_groups_master.append( - { - "params": [ - p.clone().detach().float() if self.master_weights else None - for p in param_list - ], - } - ) + if master_weights is not None: + assert isinstance(master_weights, list), "master_weights must be a list if provided" + self.master_weights = master_weights if capturable: for idx, group in enumerate(self.param_groups): @@ -123,6 +130,7 @@ def __init__( # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") self.multi_tensor_adam = tex.multi_tensor_adam + self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master @@ -147,7 +155,9 @@ def step(self, closure=None, grad_scaler=None): if closure is not None: loss = closure() - for group, group_master in zip(self.param_groups, self.param_groups_master): + master_param_idx = 0 + + for group in self.param_groups: if len(group["params"]) == 0: continue device = group["params"][0].device @@ -166,51 +176,131 @@ def step(self, closure=None, grad_scaler=None): ) # create lists for multi-tensor apply - g_16, p_16, m_16, v_16 = [], [], [], [] - g_bf, p_bf, m_bf, v_bf = [], [], [], [] - g_32, p_32, m_32, v_32 = [], [], [], [] - p_16_master = [] - p_32_master = [] - - for p, p_master in zip(group["params"], group_master["params"]): - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError("FusedAdam does not support sparse gradients.") - + p_main_of_fp8_model = [] + p_main_of_f16_model = [] + g_of_fp8_model = [] + g_of_f16_model = [] + g_of_f32_model = [] + m_of_fp8_model = [] + m_of_f16_model = [] + m_of_f32_model = [] + v_of_fp8_model = [] + v_of_f16_model = [] + v_of_f32_model = [] + p_fp8_model = [] + p_f16_model = [] + p_f32_model = [] + # fp8 meta + scales = [] + amaxes = [] + scale_invs = [] + + # Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is. + out_dtype = tex.DType.kFloat32 + + has_fp16 = False + has_bf16 = False + + for p in group["params"]: state = self.state[p] + # State initialization if len(state) == 0: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p.data).float() # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p.data).float() + # Master weights + if self.master_weights and p.dtype != torch.float32: + # model weights can be fp32/bf16/fp16/fp8 + # If it's fp32, it has no corresponding master weights + state["master_param"] = self.master_weights[master_param_idx] + master_param_idx += 1 + assert ( + state["master_param"].shape == p.shape + ), "Master weights shape must match model weights shape" + else: + state["master_param"] = None + + p_master = state["master_param"] + p_grad = p.grad + + if self.master_weights and p_master is not None and p_master.grad is not None: + p_grad = p_master.grad + + if p_grad is None: + continue + if p_grad.data.is_sparse: + raise RuntimeError("FusedAdam does not support sparse gradients.") - if p.dtype == torch.float16: + if isinstance(p, Float8Tensor): + out_dtype = p._fp8_dtype + p_fp8_model.append(p._data.data) + scale, amax, scale_inv = get_fp8_meta(p) + scales.append(scale) + amaxes.append(amax) + scale_invs.append(scale_inv) if self.master_weights: - p_16_master.append(p_master.data) - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state["exp_avg"]) - v_16.append(state["exp_avg_sq"]) - elif p.dtype == torch.bfloat16: - g_bf.append(p.grad) - p_bf.append(p) - m_bf.append(state["exp_avg"]) - v_bf.append(state["exp_avg_sq"]) - elif p.dtype == torch.float32: + p_main_of_fp8_model.append(p_master.data) + g_of_fp8_model.append(p_grad.data) + m_of_fp8_model.append(state["exp_avg"]) + v_of_fp8_model.append(state["exp_avg_sq"]) + elif p.dtype in [torch.float16, torch.bfloat16]: + has_fp16 = has_fp16 or p.dtype == torch.float16 + has_bf16 = has_bf16 or p.dtype == torch.bfloat16 + p_f16_model.append(p.data) if self.master_weights: - p_32_master.append(p_master.data) - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state["exp_avg"]) - v_32.append(state["exp_avg_sq"]) + p_main_of_f16_model.append(p_master.data) + g_of_f16_model.append(p_grad.data) + m_of_f16_model.append(state["exp_avg"]) + v_of_f16_model.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + p_f32_model.append(p.data) + g_of_f32_model.append(p_grad.data) + m_of_f32_model.append(state["exp_avg"]) + v_of_f32_model.append(state["exp_avg_sq"]) else: - raise RuntimeError("FusedAdam only support fp16 and fp32.") + raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8") + + if self.capturable and len(p_fp8_model) > 0: + raise RuntimeError( + "FusedAdam does not support FP8 model weights with capturable=True." + ) + + if has_fp16 and has_bf16: + # simple to add support for this, but not needed for now + raise RuntimeError( + "FusedAdam does not support a mix of float16 and bfloat16 model weights." + ) + + def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=None): + # Closures defined in a loop can have unexpected + # behavior when called outside the loop. However, this + # function is called in the same loop iteration as it + # is defined. + # pylint: disable=cell-var-from-loop + inv_scale_arg = () if inv_scale is None else (inv_scale,) + out_dtype_arg = () if out_dtype is None else (out_dtype,) + multi_tensor_applier( + adam_func, + self._dummy_overflow_buf, + tensor_lists, + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + *inv_scale_arg, + *out_dtype_arg, + ) - # If the optimizer is capturable, then if there's a grad scaler it works - # on the GPU + a different multi_tensor_applier should be called if self.capturable: + # If the optimizer is capturable, then if there's a grad scaler it works + # on the GPU + a different multi_tensor_applier should be called + # overflow check of gradients found_inf = ( grad_scaler._check_inf_per_device(self)[device] @@ -228,113 +318,76 @@ def step(self, closure=None, grad_scaler=None): scale = torch.ones((1,), device=device) inv_scale = torch.ones((1,), device=device) - if len(g_16) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - ( - [g_16, p_16, m_16, v_16, p_16_master] - if self.master_weights - else [g_16, p_16, m_16, v_16] - ), - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - - if len(g_bf) > 0: - multi_tensor_applier( - self.multi_tensor_adam_capturable, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - - if len(g_32) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - ( - [g_32, p_32, m_32, v_32, p_32_master] - if self.master_weights - else [g_32, p_32, m_32, v_32] - ), - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - else: - if len(g_16) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - if len(g_bf) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - if len(g_32) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) + if self.master_weights: + if len(p_f16_model) > 0: + tensor_lists = [ + g_of_f16_model, + p_f16_model, + m_of_f16_model, + v_of_f16_model, + p_main_of_f16_model, + ] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable_master, tensor_lists, inv_scale + ) + if len(p_f32_model) > 0: + tensor_lists = [ + g_of_f32_model, + p_f32_model, + m_of_f32_model, + v_of_f32_model, + ] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable, tensor_lists, inv_scale + ) + else: + if len(p_f16_model) > 0: + tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable, tensor_lists, inv_scale + ) + if len(p_f32_model) > 0: + tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] + apply_multi_tensor_adam( + self.multi_tensor_adam_capturable, tensor_lists, inv_scale + ) + + elif self.master_weights: # and self.capturable=False + if len(p_f16_model) > 0: + tensor_lists = [ + g_of_f16_model, + p_f16_model, + m_of_f16_model, + v_of_f16_model, + p_main_of_f16_model, + ] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + if len(p_fp8_model) > 0: + tensor_lists = [ + g_of_fp8_model, + p_fp8_model, + m_of_fp8_model, + v_of_fp8_model, + p_main_of_fp8_model, + scales, + amaxes, + scale_invs, + ] + apply_multi_tensor_adam(self.multi_tensor_adam_fp8, tensor_lists, out_dtype) + if len(p_f32_model) > 0: + tensor_lists = [ + g_of_f32_model, + p_f32_model, + m_of_f32_model, + v_of_f32_model, + ] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + else: # self.master_weights=False and self.capturable=False + if len(p_f16_model) > 0: + tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + if len(p_f32_model) > 0: + tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) return loss diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 3632d2f367..4fb8a28857 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -329,25 +329,22 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: return False # sk must be 16 ~ 16384 if sk % 8 != 0: return False # sk must be divisor of 8 - if self.attn_mask_type == "arbitrary": - return False # Custom masks not supported - + if sq == 1: + return False # sq must be > 1 if self.attn_mask_type == "causal" and sq != sk: return False # Fused causal kernel only support causal_bottom_right if ( sq % 4 == 0 # sq must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4 - and self.attn_mask_type != "arbitrary" # Custom masks not supported ): batch_per_block = self.get_batch_per_block(int(sk)) - - if self.attn_mask_type == "padding": + if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary": if ( mask is not None and sq % batch_per_block == 0 - and mask.shape[-2] == sq - and mask.shape[-1] == sk + and mask.shape[0] in [1, b] + and mask.shape[1:] == (1, sq, sk) ): return True else: @@ -358,13 +355,21 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: def forward_fused_softmax( self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None ) -> torch.Tensor: - """Fused masked softmax kernel""" + """ + Fused masked softmax path. + attn_mask_type | module + ----------------------------------------------------------------------------------------- + no_mask | ScaledSoftmax + causal (self-attention), causal_bottom_right | ScaledAlignedCausalMaskedSoftmax + padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax + arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk]) | ScaledMaskedSoftmax + """ scale = 1.0 if scale is None else scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) - # input is 4D tensor (b, np, sq, sk) + # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": return ScaledMaskedSoftmax.apply(inp, mask, scale) return ScaledSoftmax.apply(inp, scale) @@ -379,13 +384,19 @@ def forward_torch_softmax( if scale is not None: inp = inp * scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: assert self.kvcache_max_seq >= seq_len_k - mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask) + causal_mask = _get_onnx_export_causal_mask( + seq_len_q, seq_len_k, self.onnx_causal_mask + ) + else: + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + if mask is None: + mask = causal_mask else: - mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + mask = torch.logical_or(mask, causal_mask) mask_output = inp if mask is not None and self.attn_mask_type != "no_mask": diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 130cf91f0e..bd6e27594d 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -487,6 +487,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -500,13 +501,16 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. + cp_comm_type : str + inter-gpu communication type for context parallelism. + Can be "p2p" or "all_gather". """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): if index == 0: continue if hasattr(child, "set_context_parallel_group"): - child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream) + child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type) def forward( self, @@ -525,6 +529,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, ) -> torch.Tensor: """ @@ -600,11 +608,23 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + max_seqlen_q: Optional[int], default = `None` + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q` if not provided. + max_seqlen_kv: Optional[int], default = `None` + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. """ if self_attn_mask_type is None: @@ -652,7 +672,7 @@ def forward( hidden_states, attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, - window_size=enc_dec_window_size, + window_size=window_size, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -660,6 +680,10 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, ) @@ -679,6 +703,8 @@ def forward( inter_attention_outputs = self.inter_attention( hidden_states, attention_mask=enc_dec_attn_mask, + attn_mask_type=enc_dec_attn_mask_type, + window_size=enc_dec_window_size, encoder_output=encoder_output, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention,