Introduce NVSHMEM based communication API for pytorch #1430


Merged: 15 commits, Apr 4, 2025
15 changes: 15 additions & 0 deletions build_tools/pytorch.py
@@ -89,6 +89,19 @@ def setup_pytorch_extension(
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")

library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
include_dirs.append(nvshmem_home / "include")
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
@@ -102,4 +115,6 @@ def setup_pytorch_extension(
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
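Both this PyTorch extension change and the setup.py change below key off the same two environment variables. A minimal sketch of driving such a build from Python, assuming an NVSHMEM installation under /opt/nvshmem (the path and the pip invocation are illustrative, not part of this PR):

```python
# Hedged sketch: enable the NVSHMEM code paths at build time.
import os
import subprocess

env = dict(os.environ)
env["NVTE_ENABLE_NVSHMEM"] = "1"      # gates the new code paths in build_tools/pytorch.py and setup.py
env["NVSHMEM_HOME"] = "/opt/nvshmem"  # must contain include/ and lib/; asserted when the flag is set
subprocess.run(["pip", "install", "-v", "."], env=env, check=True)
```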
6 changes: 6 additions & 0 deletions setup.py
@@ -64,6 +64,12 @@ def setup_common_extension() -> CMakeExtension:
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

9 changes: 9 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -96,6 +96,8 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")



# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
@@ -114,6 +116,13 @@ if (NVTE_UB_WITH_MPI)
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()

option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
if (NVTE_ENABLE_NVSHMEM)
add_subdirectory(nvshmem_api)
target_link_libraries(transformer_engine PUBLIC nvshmemapi)
target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
endif()

# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)

4 changes: 3 additions & 1 deletion transformer_engine/common/libtransformer_engine.version
@@ -16,7 +16,9 @@
transformer_engine::is_fp8_dtype*;
*transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore*
*transformer_engine::CommOverlapCore*;
*nvshmem_wait_on_stream*;
*nvshmemi_init_thread*
};
local: *;
};
27 changes: 27 additions & 0 deletions transformer_engine/common/nvshmem_api/CMakeLists.txt
@@ -0,0 +1,27 @@
##########################################################################
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
##########################################################################
cmake_minimum_required (VERSION 3.18)
project(nvshmemapi LANGUAGES CXX CUDA)

# Configure dependencies
find_package(CUDAToolkit REQUIRED)
# find_package(MPI REQUIRED)
set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")

add_library(nvshmemapi STATIC nvshmem_waitkernel.cu)
set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)
target_link_directories(nvshmemapi PUBLIC ${NVSHMEM_HOME}/lib)
target_link_libraries(nvshmemapi PUBLIC -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver)
target_include_directories(nvshmemapi PRIVATE
${NVSHMEM_HOME}/include/)
target_include_directories(nvshmemapi PUBLIC
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
"${CMAKE_CURRENT_SOURCE_DIR}")

set_target_properties(nvshmemapi PROPERTIES
CUDA_STANDARD 17
POSITION_INDEPENDENT_CODE ON
CUDA_SEPARABLE_COMPILATION ON)
51 changes: 51 additions & 0 deletions transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
@@ -0,0 +1,51 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cuda.h>
#include <cuda_bf16.h>
#include <nvshmem.h>

#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

#include "../util/logging.h"
#include "nvshmem_waitkernel.h"

__global__ void __launch_bounds__(1)
wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value,
uint64_t signal_reset) {
nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
*wait_flag = signal_reset;
}
void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream) {
uint64_t wait_value = 1;
uint64_t signal_reset = 0;
cudaStream_t cur_stream = stream;

NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && wait_kind <= WaitKind::STREAM_WAIT,
"Invalid wait kind: ", static_cast<int>(wait_kind));

switch (wait_kind) {
case WaitKind::KERNEL_WAIT:
wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
break;
case WaitKind::NVSHMEM_WAIT:
nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
break;
case WaitKind::STREAM_WAIT:
cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
CU_STREAM_WAIT_VALUE_GEQ);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
break;
}
}
38 changes: 38 additions & 0 deletions transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h
@@ -0,0 +1,38 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H

#include <cuda_runtime.h>

#ifdef __cplusplus
#include <cstdint>
extern "C" {
#else
#include <stdint.h>
#endif

/*! \enum WaitKind
* \brief Types of wait operations that can be performed.
*/
enum class WaitKind {
KERNEL_WAIT = 0, /*!< Wait using a CUDA kernel */
NVSHMEM_WAIT = 1, /*!< Wait using NVSHMEM wait operation */
STREAM_WAIT = 2 /*!< Wait using CUDA stream synchronization */
};

/*! \brief Wait on a signal until a certain condition is met.
*
* \param[in] sig_addr The address of the signal to wait on.
* \param[in] wait_kind The kind of wait to perform.
* \param[in] stream The stream to wait on.
*/
void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif

#endif // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
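The three WaitKind variants map one-to-one onto the wait_kind strings accepted by the Python-facing nvshmem_wait_on_current_stream binding in the PyTorch extension files below. A hedged sketch, assuming the compiled extension is importable as transformer_engine_torch, that init_nvshmem_backend has already run, and that a one-element 64-bit tensor serves as the signal slot (all of these are assumptions, not part of this diff):

```python
import torch
import transformer_engine_torch as tex  # module name is an assumption

# Assumes a peer will (or already did) post a send that sets this signal to 1.
signal = tex.create_nvshmem_tensor([1], torch.int64)  # one 64-bit signal word on the symmetric heap
signal.zero_()

# "kernel"  -> WaitKind.KERNEL_WAIT  (single-thread kernel spinning in nvshmem_uint64_wait_until)
# "nvshmem" -> WaitKind.NVSHMEM_WAIT (nvshmemx_uint64_wait_until_on_stream, then a stream write resets it)
# "stream"  -> WaitKind.STREAM_WAIT  (cuStreamWaitValue64, then cuStreamWriteValue64 resets it)
tex.nvshmem_wait_on_current_stream(signal, "stream")  # blocks the current stream until the signal reaches 1
```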
17 changes: 17 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -373,6 +373,23 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list);

/***************************************************************************************************
* NVSHMEM APIs
**************************************************************************************************/

namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);

torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);

void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal);

void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);

void nvshmem_finalize();
} // namespace nvshmem_api

/***************************************************************************************************
* swizzle
**************************************************************************************************/
129 changes: 129 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
@@ -0,0 +1,129 @@
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../extensions.h"

#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif

#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>

namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmemx_init_attr_t attr = {};
nvshmemx_uniqueid_t id = {};

int my_rank = process_group->getRank();
int num_ranks = process_group->getSize();
if (my_rank == 0) {
nvshmemx_get_uniqueid(&id);
}

auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
NVTE_CHECK(backend_is_nccl, "Currently only NCCL bootstrap is supported for NVSHMEM");
auto datatensor =
torch::from_blob(reinterpret_cast<void *>(&id),
{static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;

c10d::BroadcastOptions bcast_opts;
bcast_opts.rootRank = 0;
std::vector<torch::Tensor> datachunk = {datatmp};
auto work = process_group->broadcast(datachunk, bcast_opts);
work->wait();

if (backend_is_nccl) {
datatensor.copy_(datatmp.cpu());
datatmp = torch::Tensor();
}

nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
" != nvshmem_my_pe(): ", nvshmem_my_pe());
NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
" != nvshmem_n_pes(): ", nvshmem_n_pes());
#else
NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();

WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;

if (wait_kind == "kernel") {
wait_kind_enum = WaitKind::KERNEL_WAIT;
} else if (wait_kind == "nvshmem") {
wait_kind_enum = WaitKind::NVSHMEM_WAIT;
} else if (wait_kind == "stream") {
wait_kind_enum = WaitKind::STREAM_WAIT;
} else {
NVTE_ERROR("Invalid wait kind: ", wait_kind);
}
nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);

#else
NVTE_ERROR(
"Internal TE error: nvshmem_wait_on_current_stream is only available when TE is compiled ",
"with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
auto option_gpu =
at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
auto size = torch::elementSize(dtype) *
std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
return at::from_blob(
nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
#else
NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
auto nelement = src.numel() * src.element_size();
uint64_t sigval = 1;
at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();

nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
peer, (cudaStream_t)cur_stream);
#else
NVTE_ERROR(
"Internal TE error: nvshmem_send_on_current_stream is only available when TE is compiled ",
"with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
// Call the library function in the global namespace, not this wrapper, to avoid infinite recursion.
::nvshmem_finalize();
#else
NVTE_ERROR("Internal TE error: nvshmem_finalize is only available when TE is compiled ",
"with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
} // namespace nvshmem_api
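Putting the pieces together: the sender calls nvshmem_send_on_current_stream, which puts the payload into the peer's symmetric buffer and sets the peer's signal to 1 (NVSHMEM_SIGNAL_SET); the receiver calls nvshmem_wait_on_current_stream, which blocks its stream until the signal is 1 and resets it to 0. A hedged two-rank sketch, assuming the backend has already been initialized as in the lifecycle example after the pybind11 registrations below; the module name, shapes, and dtypes are illustrative:

```python
import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # module name is an assumption

rank = dist.get_rank()
# Symmetric allocations are collective: every PE creates the same buffers.
data = tex.create_nvshmem_tensor([1024], torch.bfloat16)
signal = tex.create_nvshmem_tensor([1], torch.int64)
signal.zero_()

if rank == 0:
    payload = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
    # Put `payload` into PE 1's `data` buffer, then set PE 1's `signal` to 1.
    tex.nvshmem_send_on_current_stream(payload, data, 1, signal)
elif rank == 1:
    # Block the current stream until the signal arrives, then reset it to 0.
    tex.nvshmem_wait_on_current_stream(signal, "stream")
    received = data.clone()
```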
17 changes: 17 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -234,6 +234,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());

// nvshmem functions
m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
"Initialize nvshmem backend with Pytorch distributed process groups",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
"Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
"Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
"Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
"stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());

// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors",
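For reference, a hedged end-to-end lifecycle sketch for the five bindings registered above. The extension module name, the process-group retrieval, and the launcher setup are assumptions, not part of this diff; only NCCL-backed process groups are accepted by init_nvshmem_backend:

```python
import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # module name is an assumption

dist.init_process_group(backend="nccl")  # NCCL is the only supported bootstrap
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Broadcast the NVSHMEM unique ID from rank 0 and initialize NVSHMEM on every rank.
tex.init_nvshmem_backend(dist.distributed_c10d._get_default_group())

buf = tex.create_nvshmem_tensor([4096], torch.bfloat16)  # lives on the NVSHMEM symmetric heap
# ... nvshmem_send_on_current_stream / nvshmem_wait_on_current_stream as in the sketch above ...

tex.nvshmem_finalize()
dist.destroy_process_group()
```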