From eff882d1a778bd5c89c65869432ce6f720a1737c Mon Sep 17 00:00:00 2001
From: gdeng
Date: Tue, 28 Jan 2025 10:55:13 -0800
Subject: [PATCH] add nvshmem based api support

Signed-off-by: gdeng
---
 build_tools/pytorch.py                           |  15 +++
 setup.py                                         |   6 +
 transformer_engine/common/CMakeLists.txt         |   9 ++
 .../common/nvshmem_api/CMakeLists.txt            |  31 +++++
 .../common/nvshmem_api/nvshmem_waitkernel.cu     |  49 +++++++
 .../common/nvshmem_api/nvshmem_waitkernel.h      |  13 ++
 transformer_engine/pytorch/csrc/extensions.h     |  13 ++
 .../pytorch/csrc/extensions/nvshmem_comm.cpp     | 120 ++++++++++++++++++
 .../pytorch/csrc/extensions/pybind.cpp           |  15 +++
 9 files changed, 271 insertions(+)
 create mode 100644 transformer_engine/common/nvshmem_api/CMakeLists.txt
 create mode 100644 transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
 create mode 100644 transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h
 create mode 100644 transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp

diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index f060e99dff..26dae3b8dd 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -90,6 +90,19 @@ def setup_pytorch_extension(
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
         nvcc_flags.append("-DNVTE_UB_WITH_MPI")
 
+    library_dirs = []
+    libraries = []
+    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
+        assert (
+            os.getenv("NVSHMEM_HOME") is not None
+        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
+        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
+        include_dirs.append(nvshmem_home / "include")
+        library_dirs.append(nvshmem_home / "lib")
+        libraries.append("nvshmem_host")
+        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
+        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
+
     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
     include_dirs = [str(path) for path in include_dirs]
@@ -103,4 +116,6 @@ def setup_pytorch_extension(
             "cxx": cxx_flags,
             "nvcc": nvcc_flags,
         },
+        libraries=[str(lib) for lib in libraries],
+        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )
diff --git a/setup.py b/setup.py
index 643dd7a908..ea403d340e 100644
--- a/setup.py
+++ b/setup.py
@@ -64,6 +64,12 @@ def setup_common_extension() -> CMakeExtension:
         ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
         cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")
 
+    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
+        assert (
+            os.getenv("NVSHMEM_HOME") is not None
+        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
+        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")
+
    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
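For reference, a build that exercises the two new environment flags might be driven as below. This is a minimal sketch, not part of the patch; the NVSHMEM install prefix is a placeholder.

    # Sketch: build Transformer Engine with NVSHMEM support enabled.
    # "/opt/nvshmem" is hypothetical; point NVSHMEM_HOME at a real install.
    import os
    import subprocess

    env = dict(os.environ)
    env["NVTE_ENABLE_NVSHMEM"] = "1"
    env["NVSHMEM_HOME"] = "/opt/nvshmem"
    subprocess.run(["pip", "install", "-v", "."], check=True, env=env)
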
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 3afddcc48d..2e69f07b00 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -91,6 +91,8 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")
 
+
+
 # Configure dependencies
 target_link_libraries(transformer_engine PUBLIC
                       CUDA::cublas
@@ -108,6 +110,13 @@
     target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
 endif()
 
+option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
+if (NVTE_ENABLE_NVSHMEM)
+  add_subdirectory(nvshmem_api)
+  target_link_libraries(transformer_engine PUBLIC nvshmemapi)
+  target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
+endif()
+
 # Hack to enable dynamic loading in cuDNN frontend
 target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
diff --git a/transformer_engine/common/nvshmem_api/CMakeLists.txt b/transformer_engine/common/nvshmem_api/CMakeLists.txt
new file mode 100644
index 0000000000..fca9490363
--- /dev/null
+++ b/transformer_engine/common/nvshmem_api/CMakeLists.txt
@@ -0,0 +1,31 @@
+##########################################################################
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+##########################################################################
+cmake_minimum_required(VERSION 3.18)
+project(nvshmemapi LANGUAGES CXX CUDA)
+
+# Configure dependencies
+find_package(CUDAToolkit REQUIRED)
+# find_package(MPI REQUIRED)
+set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")
+
+add_library(nvshmemapi SHARED nvshmem_waitkernel.cu)
+set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)
+target_link_directories(nvshmemapi PRIVATE ${NVSHMEM_HOME}/lib)
+target_link_libraries(nvshmemapi PRIVATE -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver)
+target_include_directories(nvshmemapi PRIVATE
+                           ${NVSHMEM_HOME}/include/)
+target_include_directories(nvshmemapi PUBLIC
+                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+                           "${CMAKE_CURRENT_SOURCE_DIR}")
+
+set_target_properties(nvshmemapi PROPERTIES
+                      CUDA_STANDARD 17
+                      CUDA_RESOLVE_DEVICE_SYMBOLS ON
+                      POSITION_INDEPENDENT_CODE ON
+                      CUDA_SEPARABLE_COMPILATION ON)
+
+# libnvshmemapi.so is installed in TransformerEngine/, alongside libtransformer_engine.so and the transformer_engine_common extension module
+install(TARGETS nvshmemapi DESTINATION ./)
\ No newline at end of file
diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
new file mode 100644
index 0000000000..16a6edb09f
--- /dev/null
+++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
@@ -0,0 +1,49 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cuda.h>
+#include <nvshmem.h>
+#include <nvshmemx.h>
+
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <string>
+
+#include "nvshmem_waitkernel.h"
+
+namespace transformer_engine {
+
+__global__ void __launch_bounds__(1)
+    wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value,
+                                   uint64_t signal_reset) {
+  nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
+  *wait_flag = signal_reset;
+}
+
+void nvshmem_wait_on_stream(uint64_t* sig_addr, int wait_kind, cudaStream_t stream) {
+  uint64_t wait_value = 1;
+  uint64_t signal_reset = 0;
+  cudaStream_t cur_stream = stream;
+
+  assert(wait_kind >= 0 && wait_kind <= 2);
+
+  if (wait_kind == 0) {
+    // Wait with a single-thread NVSHMEM kernel, which also resets the flag on device
+    wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
+  } else if (wait_kind == 1) {
+    // NVSHMEM host-side stream wait, then reset the flag with a stream write
+    nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
+    cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
+                         CU_STREAM_WRITE_VALUE_DEFAULT);
+  } else if (wait_kind == 2) {
+    // Pure CUDA driver wait on the signal value
+    cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
+                        CU_STREAM_WAIT_VALUE_GEQ);
+    // Reset local flag to 0
+    cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
+                         CU_STREAM_WRITE_VALUE_DEFAULT);
+  }
+}
+
+}  // namespace transformer_engine
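The integer wait_kind selects among three flavors of stream-ordered waiting. A small Python-side mapping (illustrative only; the constant names are invented here, only the integer values come from the kernel above) keeps call sites readable:

    # Illustrative names for the wait_kind argument of nvshmem_wait_on_stream.
    KERNEL_WAIT = 0          # 1-thread NVSHMEM wait kernel; flag reset on device
    NVSHMEM_STREAM_WAIT = 1  # nvshmemx_uint64_wait_until_on_stream + stream-write reset
    DRIVER_STREAM_WAIT = 2   # cuStreamWaitValue64 wait + cuStreamWriteValue64 reset
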
diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h
new file mode 100644
index 0000000000..286191fd92
--- /dev/null
+++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h
@@ -0,0 +1,13 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
+#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
+
+#include <cstdint>
+
+#include <cuda_runtime.h>
+
+namespace transformer_engine {
+void nvshmem_wait_on_stream(uint64_t* sig_addr, int wait_kind, cudaStream_t stream);
+}  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 67fd1caf5b..bc9a6c2471 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -512,6 +512,19 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list);
 
+/***************************************************************************************************
+ * NVSHMEM APIs
+ **************************************************************************************************/
+
+namespace nvshmem_api {
+void init_nvshmem_backend(c10d::ProcessGroup *process_group);
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+void nvshmem_send_on_stream(torch::Tensor src, torch::Tensor dst, int peer, torch::Tensor signal);
+void nvshmem_wait_on_stream(torch::Tensor signal, int wait_kind);
+void nvshmem_finalize();
+void nvshmem_quiet();
+}  // namespace nvshmem_api
+
 /***************************************************************************************************
  * Comm+GEMM Overlap Wrappers
  **************************************************************************************************/
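The declarations above are the whole Python-facing surface. A hedged sketch of the bootstrap they imply follows; the extension module name transformer_engine_torch and the use of the default process group are assumptions, not shown in this patch.

    # Sketch: initialize NVSHMEM over a NCCL process group via the new binding.
    import torch
    import torch.distributed as dist
    import transformer_engine_torch as tex  # module name assumed

    dist.init_process_group(backend="nccl")  # NCCL is required by the binding
    torch.cuda.set_device(dist.get_rank())
    tex.init_nvshmem_backend(dist.group.WORLD)
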
diff --git a/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp b/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
new file mode 100644
index 0000000000..2c294945e5
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
@@ -0,0 +1,120 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "../extensions.h"
+
+#ifdef NVTE_ENABLE_NVSHMEM
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include "../../../common/nvshmem_api/nvshmem_waitkernel.h"
+#endif
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAFunctions.h>
+#include <cassert>
+#include <numeric>
+
+namespace nvshmem_api {
+
+void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  nvshmemx_init_attr_t attr = {};
+  nvshmemx_uniqueid_t id = {};
+
+  int my_rank = process_group->getRank();
+  int num_ranks = process_group->getSize();
+  if (my_rank == 0) {
+    nvshmemx_get_uniqueid(&id);
+  }
+
+  auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
+  NVTE_CHECK(backend_is_nccl, "Currently, only NCCL bootstrap is supported for NVSHMEM.");
+
+  // Broadcast the NVSHMEM unique ID from rank 0 to all ranks. NCCL broadcasts
+  // device tensors, so the ID is staged through a GPU copy.
+  auto datatensor = torch::from_blob(
+      (void *)&id, {static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
+      at::device(torch::kCPU).dtype(torch::kUInt8));
+  auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;
+
+  c10d::BroadcastOptions bcast_opts;
+  bcast_opts.rootRank = 0;
+  std::vector<torch::Tensor> datachunk = {datatmp};
+  auto work = process_group->broadcast(datachunk, bcast_opts);
+  work->wait();
+
+  if (backend_is_nccl) {
+    datatensor.copy_(datatmp.cpu());
+    datatmp = torch::Tensor();
+  }
+
+  nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
+  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+
+  assert(my_rank == nvshmem_my_pe());
+  assert(num_ranks == nvshmem_n_pes());
+#else
+  NVTE_ERROR("Internal TE error: init_nvshmem_backend is only available when TE is compiled ",
+             "with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+void nvshmem_wait_on_stream(torch::Tensor signal, int wait_kind) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  uint64_t *sig_addr = (uint64_t *)signal.data_ptr();
+  cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();
+
+  transformer_engine::nvshmem_wait_on_stream(sig_addr, wait_kind, cur_stream);
+#else
+  NVTE_ERROR("Internal TE error: nvshmem_wait_on_stream is only available when TE is compiled ",
+             "with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  auto option_gpu =
+      at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
+  auto size = torch::elementSize(dtype) *
+              std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
+  // Allocate from the NVSHMEM symmetric heap; free it when the tensor is destroyed.
+  return at::from_blob(
+      nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
+#else
+  NVTE_ERROR("Internal TE error: create_nvshmem_tensor is only available when TE is compiled ",
+             "with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+void nvshmem_send_on_stream(torch::Tensor src, torch::Tensor dst, int peer, torch::Tensor signal) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  void *src_ptr = (void *)src.data_ptr();
+  void *dst_ptr = (void *)dst.data_ptr();
+  uint64_t *sig_addr = (uint64_t *)signal.data_ptr();
+  auto nelement = src.numel() * src.element_size();
+  uint64_t sigval = 1;
+  at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
+
+  // Put the payload into the peer's buffer, then set the peer's signal flag to 1.
+  nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
+                                   peer, (cudaStream_t)cur_stream);
+#else
+  NVTE_ERROR("Internal TE error: nvshmem_send_on_stream is only available when TE is compiled ",
+             "with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+void nvshmem_finalize() {
+#ifdef NVTE_ENABLE_NVSHMEM
+  // Qualify with :: so this calls the NVSHMEM library function rather than
+  // recursing into this same-named wrapper.
+  ::nvshmem_finalize();
+#else
+  NVTE_ERROR("Internal TE error: nvshmem_finalize is only available when TE is compiled ",
+             "with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+void nvshmem_quiet() {
+#ifdef NVTE_ENABLE_NVSHMEM
+  // Likewise, :: selects the NVSHMEM library's nvshmem_quiet.
+  ::nvshmem_quiet();
+#else
+  NVTE_ERROR("Internal TE error: nvshmem_quiet is only available when TE is compiled ",
+             "with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+}  // namespace nvshmem_api
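Putting the pieces together, an end-to-end two-rank sketch of the API follows; shapes, dtypes, and the module name are illustrative, and the int64 signal buffer is simply reinterpreted as a uint64_t flag by the C++ side.

    # Illustrative ping from rank 0 to rank 1: put 16 floats into the peer's
    # symmetric buffer and set its signal flag; the peer blocks its stream
    # until the flag arrives (wait_kind 0 = kernel wait).
    import torch
    import torch.distributed as dist
    import transformer_engine_torch as tex  # module name assumed

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    tex.init_nvshmem_backend(dist.group.WORLD)

    data = tex.create_nvshmem_tensor([16], torch.float32)  # symmetric heap
    signal = tex.create_nvshmem_tensor([1], torch.int64)   # 8-byte flag
    signal.zero_()

    if rank == 0:
        src = torch.arange(16, dtype=torch.float32, device="cuda")
        tex.nvshmem_send_on_stream(src, data, 1, signal)
    else:
        tex.nvshmem_wait_on_stream(signal, 0)
        print(data)  # printing a CUDA tensor synchronizes the stream

    tex.nvshmem_quiet()
    tex.nvshmem_finalize()
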
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 165855d430..169e6e7b0f 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -199,6 +199,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());
+  // nvshmem functions
+  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
+        "Initialize the NVSHMEM backend with a PyTorch distributed process group",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
+        "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_send_on_stream", &nvshmem_api::nvshmem_send_on_stream,
+        "Send on stream using the NVSHMEM backend", py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_wait_on_stream", &nvshmem_api::nvshmem_wait_on_stream,
+        "Wait on stream using the NVSHMEM backend", py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize, "Tear down the NVSHMEM backend",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_quiet", &nvshmem_api::nvshmem_quiet,
+        "Ensure completion of all outstanding NVSHMEM operations issued by the calling PE",
+        py::call_guard<py::gil_scoped_release>());
+
   // multi-tensor functions
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",