Introduce NVSHMEM-based communication API for PyTorch #1430

Open · wants to merge 1 commit into base: main
15 changes: 15 additions & 0 deletions build_tools/pytorch.py
@@ -90,6 +90,19 @@ def setup_pytorch_extension(
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")

library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
include_dirs.append(nvshmem_home / "include")
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
@@ -103,4 +116,6 @@ def setup_pytorch_extension(
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
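
For reference, a minimal sketch of driving the new build gate. Only the `NVTE_ENABLE_NVSHMEM` and `NVSHMEM_HOME` environment variables come from this diff; the pip invocation and the `/opt/nvshmem` path are illustrative assumptions.

```python
# Hedged sketch: building the PyTorch extension with the NVSHMEM path enabled.
# NVTE_ENABLE_NVSHMEM and NVSHMEM_HOME are the variables checked by build_tools/pytorch.py
# and setup.py; the pip command and install prefix below are assumptions.
import os
import subprocess

env = dict(os.environ)
env["NVTE_ENABLE_NVSHMEM"] = "1"      # adds -DNVTE_ENABLE_NVSHMEM to cxx/nvcc flags and links nvshmem_host
env["NVSHMEM_HOME"] = "/opt/nvshmem"  # must contain include/ and lib/ from an NVSHMEM install

subprocess.run(["pip", "install", "--no-build-isolation", "-v", "."], env=env, check=True)
```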
6 changes: 6 additions & 0 deletions setup.py
@@ -64,6 +64,12 @@ def setup_common_extension() -> CMakeExtension:
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

9 changes: 9 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -91,6 +91,8 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")



# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
@@ -108,6 +110,13 @@ if (NVTE_UB_WITH_MPI)
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()

option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
if (NVTE_ENABLE_NVSHMEM)
add_subdirectory(nvshmem_api)
target_link_libraries(transformer_engine PUBLIC nvshmemapi)
target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
endif()

# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)

31 changes: 31 additions & 0 deletions transformer_engine/common/nvshmem_api/CMakeLists.txt
@@ -0,0 +1,31 @@
##########################################################################
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
##########################################################################
cmake_minimum_required (VERSION 3.18)
project(nvshmemapi LANGUAGES CXX CUDA)

# Configure dependencies
find_package(CUDAToolkit REQUIRED)
# find_package(MPI REQUIRED)
set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")

add_library(nvshmemapi SHARED nvshmem_waitkernel.cu)
set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)
target_link_directories(nvshmemapi PRIVATE ${NVSHMEM_HOME}/lib)
target_link_libraries(nvshmemapi PRIVATE -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver)
target_include_directories(nvshmemapi PRIVATE
${NVSHMEM_HOME}/include/)
target_include_directories(nvshmemapi PUBLIC
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
"${CMAKE_CURRENT_SOURCE_DIR}")

set_target_properties(nvshmemapi PROPERTIES
CUDA_STANDARD 17
CUDA_RESOLVE_DEVICE_SYMBOLS ON
POSITION_INDEPENDENT_CODE ON
CUDA_SEPARABLE_COMPILATION ON)

# This means libnvshmemapi.so will be installed in TransformerEngine/, alongside libtransformer_engine.so and transformer_engine_common.cpython-310-x86_64-linux-gnu.so
install(TARGETS nvshmemapi DESTINATION ./)
49 changes: 49 additions & 0 deletions transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
@@ -0,0 +1,49 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cuda.h>
#include <nvshmem.h>
#include <nvshmemx.h>

#include <cassert>
#include <cstdint>

#include "nvshmem_waitkernel.h"

namespace transformer_engine {
__global__ void __launch_bounds__(1)
wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value,
uint64_t signal_reset) {
nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
*wait_flag = signal_reset;
}
void nvshmem_wait_on_stream(uint64_t* sig_addr, int wait_kind, cudaStream_t stream) {
uint64_t wait_value = 1;
uint64_t signal_reset = 0;
cudaStream_t cur_stream = stream;

// wait_kind selects how the wait is implemented:
//   0: single-thread kernel spinning in nvshmem_uint64_wait_until, then resetting the flag
//   1: nvshmemx_uint64_wait_until_on_stream, with the reset enqueued via cuStreamWriteValue64
//   2: cuStreamWaitValue64 (GEQ) on the stream, with the reset enqueued via cuStreamWriteValue64
assert(wait_kind >= 0 && wait_kind <= 2);

if (wait_kind == 0) {
wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
} else if (wait_kind == 1) {
nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
} else if (wait_kind == 2) {
cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
CU_STREAM_WAIT_VALUE_GEQ);
// Reset local flag to 0
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT);
}
}

} // namespace transformer_engine
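
The signal protocol above is: block until the flag equals 1, then reset it to 0, with `wait_kind` selecting where the wait executes. Below is a hedged pattern sketch (not a runnable single-process script) of picking a mode through the Python binding registered later in this PR; the module name `transformer_engine_torch` and the int64 dtype used for the uint64 flag are assumptions, not fixed by this diff.

```python
# Hedged sketch: the three wait_kind values accepted by nvshmem_wait_on_stream(signal, wait_kind).
# Assumes the extension imports as `transformer_engine_torch`, that init_nvshmem_backend(...)
# has already run on every rank, and that the C++ side reinterprets the signal storage as uint64.
import torch
import transformer_engine_torch as tex  # module name is an assumption

KERNEL_WAIT = 0     # single-thread kernel spinning in nvshmem_uint64_wait_until, then resetting the flag
NVSHMEM_WAIT = 1    # nvshmemx_uint64_wait_until_on_stream, reset enqueued via cuStreamWriteValue64
CU_STREAM_WAIT = 2  # cuStreamWaitValue64 (GEQ) on the stream, reset enqueued via cuStreamWriteValue64

signal = tex.create_nvshmem_tensor([1], torch.int64)  # dtype choice is an assumption
signal.zero_()  # nvshmem_malloc does not zero memory; start with the flag cleared

# ... a peer calls nvshmem_send_on_stream(src, dst, peer, signal), which sets the flag to 1 ...
tex.nvshmem_wait_on_stream(signal, CU_STREAM_WAIT)  # current stream blocks until flag == 1, then resets to 0
```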
13 changes: 13 additions & 0 deletions transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h
@@ -0,0 +1,13 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H

#include <cuda_runtime.h>

#include <cstdint>

namespace transformer_engine {
void nvshmem_wait_on_stream(uint64_t* sig_addr, int wait_kind, cudaStream_t stream);
}
#endif // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
13 changes: 13 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -512,6 +512,19 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list);

/***************************************************************************************************
* NVSHMEM APIs
**************************************************************************************************/

namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
void nvshmem_send_on_stream(torch::Tensor src, torch::Tensor dst, int peer, torch::Tensor signal);
void nvshmem_wait_on_stream(torch::Tensor signal, int wait_kind);
void nvshmem_finalize();
void nvshmem_quiet();
} // namespace nvshmem_api

/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
120 changes: 120 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
@@ -0,0 +1,120 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../extensions.h"

#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif

#include <cassert>
#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>

namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmemx_init_attr_t attr = {};
nvshmemx_uniqueid_t id = {};

int my_rank = process_group->getRank();
int num_ranks = process_group->getSize();
if (my_rank == 0) {
nvshmemx_get_uniqueid(&id);
}

auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
NVTE_CHECK(backend_is_nccl, "Currently only NCCL bootstrap is supported for NVSHMEM");
auto datatensor = torch::from_blob(
(void *)&id, {static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;

c10d::BroadcastOptions bcast_opts;
bcast_opts.rootRank = 0;
std::vector<torch::Tensor> datachunk = {datatmp};
auto work = process_group->broadcast(datachunk, bcast_opts);
work->wait();

if (backend_is_nccl) {
datatensor.copy_(datatmp.cpu());
datatmp = torch::Tensor();
}

nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

assert(my_rank == nvshmem_my_pe());
assert(num_ranks == nvshmem_n_pes());
#else
NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_wait_on_stream(torch::Tensor signal, int wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
uint64_t *sig_addr = (uint64_t *)signal.data_ptr();
cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();

transformer_engine::nvshmem_wait_on_stream(sig_addr, wait_kind, cur_stream);
#else
NVTE_ERROR("Internal TE error: nvshmem_wait_on_stream cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
auto option_gpu =
at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
auto size = torch::elementSize(dtype) *
std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
return at::from_blob(
nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
#else
NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_send_on_stream(torch::Tensor src, torch::Tensor dst, int peer, torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
void *src_ptr = (void *)src.data_ptr();
void *dst_ptr = (void *)dst.data_ptr();
uint64_t *sig_addr = (uint64_t *)signal.data_ptr();
auto nelement = src.numel() * src.element_size();
uint64_t sigval = 1;
at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();

nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
peer, (cudaStream_t)cur_stream);
#else
NVTE_ERROR("Internal TE error: nvshmem_send_on_stream cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
::nvshmem_finalize();  // qualified call to the NVSHMEM library function, not this wrapper
#else
NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_quiet() {
#ifdef NVTE_ENABLE_NVSHMEM
::nvshmem_quiet();  // qualified call to the NVSHMEM library function, not this wrapper
#else
NVTE_ERROR("Internal TE error: nvshmem_quiet cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
} // namespace nvshmem_api
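
The bootstrap implemented here is: rank 0 calls nvshmemx_get_uniqueid, the ID is broadcast as a uint8 tensor over the NCCL process group, and every rank then calls nvshmemx_init_attr with NVSHMEMX_INIT_WITH_UNIQUEID. Note also that create_nvshmem_tensor wraps nvshmem_malloc in a regular torch.Tensor whose deleter is nvshmem_free, so symmetric buffers should be released before finalizing. A hedged calling sketch from Python, assuming a multi-rank torchrun-style launch, the module name `transformer_engine_torch`, and that the default process group can be passed directly:

```python
# Hedged sketch: bootstrap and symmetric-buffer lifetime, mirroring nvshmem_comm.cpp.
import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # module name is an assumption

dist.init_process_group(backend="nccl")                 # NCCL is the only supported bootstrap
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
tex.init_nvshmem_backend(dist.group.WORLD)              # rank 0 gets the unique ID and broadcasts it

buf = tex.create_nvshmem_tensor([4096], torch.bfloat16)  # collective symmetric allocation on every rank
# ... use buf with nvshmem_send_on_stream / nvshmem_wait_on_stream ...
del buf                                                   # releases the symmetric buffer via nvshmem_free
tex.nvshmem_finalize()
```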
15 changes: 15 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -199,6 +199,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());

// nvshmem functions
m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend, "Init nvshmem with helper",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor, "Create nvshmem tensor",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_stream", &nvshmem_api::nvshmem_send_on_stream,
"Send on stream using nvshmem backend", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_stream", &nvshmem_api::nvshmem_wait_on_stream,
"Wait on stream using nvshmem backend", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize, "Tear down nvshmem backend",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_quiet", &nvshmem_api::nvshmem_quiet,
"Ensure completion of all operations by the call src",
py::call_guard<py::gil_scoped_release>());

// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors",
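
Putting the bindings together, a hedged end-to-end sketch of a ring exchange: each rank pushes a buffer to its right neighbour with nvshmem_send_on_stream (put plus signal set to 1) and blocks its stream on the flag written by its left neighbour. The module name, the multi-rank torchrun-style launch, and the dtype choices are assumptions; the call sequence follows the functions registered above.

```python
# Hedged sketch: two-or-more-rank ring exchange using the new NVSHMEM bindings.
import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # module name is an assumption

dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
tex.init_nvshmem_backend(dist.group.WORLD)

dst = tex.create_nvshmem_tensor([1024], torch.bfloat16)  # symmetric landing buffer
sig = tex.create_nvshmem_tensor([1], torch.int64)        # symmetric signal flag (uint64 in C++)
sig.zero_()                                              # nvshmem_malloc does not zero memory
dist.barrier()                                           # every flag is cleared before any send

src = torch.full((1024,), float(rank), dtype=torch.bfloat16, device="cuda")
tex.nvshmem_send_on_stream(src, dst, (rank + 1) % world, sig)  # put + set the peer's flag to 1
tex.nvshmem_wait_on_stream(sig, 2)                       # block stream until flag == 1, then reset to 0
tex.nvshmem_quiet()                                      # ensure our own puts have completed
torch.cuda.synchronize()

assert torch.all(dst == float((rank - 1) % world))       # data arrived from the left neighbour
del dst, sig                                             # free symmetric buffers before finalize
tex.nvshmem_finalize()
dist.destroy_process_group()
```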