diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 64f41d65fae4..e9f585490182 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -54,10 +54,8 @@ if(USE_CUDA)
   list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS})
   list(APPEND COMPILER_SRCS src/target/opt/build_cuda_on.cc)
 
-  list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
-  list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
 
   if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
     if(CMAKE_VERSION VERSION_LESS "3.24")
diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake
index c4c18eef0f80..c62506cf4144 100644
--- a/cmake/utils/FindCUDA.cmake
+++ b/cmake/utils/FindCUDA.cmake
@@ -33,7 +33,6 @@
 # - CUDA_TOOLKIT_ROOT_DIR
 # - CUDA_CUDA_LIBRARY
 # - CUDA_CUDART_LIBRARY
-# - CUDA_NVRTC_LIBRARY
 # - CUDA_CUDNN_INCLUDE_DIRS
 # - CUDA_CUDNN_LIBRARY
 # - CUDA_CUBLAS_LIBRARY
@@ -64,9 +63,6 @@ macro(find_cuda use_cuda use_cudnn)
     find_library(CUDA_CUDA_LIBRARY cuda
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
-    find_library(CUDA_NVRTC_LIBRARY nvrtc
-                 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
-                 ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
     find_library(CUDA_CUBLAS_LIBRARY cublas
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
@@ -81,10 +77,6 @@ macro(find_cuda use_cuda use_cudnn)
     if(_CUDA_CUDA_LIBRARY)
       set(CUDA_CUDA_LIBRARY ${_CUDA_CUDA_LIBRARY})
     endif()
-    find_library(CUDA_NVRTC_LIBRARY nvrtc
-                 PATHS ${CUDA_TOOLKIT_ROOT_DIR}
-                 PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
-                 NO_DEFAULT_PATH)
     find_library(CUDA_CURAND_LIBRARY curand
                  PATHS ${CUDA_TOOLKIT_ROOT_DIR}
                  PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
                  NO_DEFAULT_PATH)
@@ -140,7 +132,6 @@ macro(find_cuda use_cuda use_cudnn)
     message(STATUS "Found CUDA_TOOLKIT_ROOT_DIR=" ${CUDA_TOOLKIT_ROOT_DIR})
     message(STATUS "Found CUDA_CUDA_LIBRARY=" ${CUDA_CUDA_LIBRARY})
     message(STATUS "Found CUDA_CUDART_LIBRARY=" ${CUDA_CUDART_LIBRARY})
-    message(STATUS "Found CUDA_NVRTC_LIBRARY=" ${CUDA_NVRTC_LIBRARY})
     message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
     message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
     message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu
index 1295c679d778..a72bd60fd77f 100644
--- a/docker/Dockerfile.ci_gpu
+++ b/docker/Dockerfile.ci_gpu
@@ -60,6 +60,9 @@ RUN bash /install/ubuntu_install_opencl.sh
 COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
 RUN bash /install/ubuntu_install_python_package.sh
 
+COPY install/ubuntu_install_cuda_python.sh /install/ubuntu_install_cuda_python.sh
+RUN bash /install/ubuntu_install_cuda_python.sh
+
 COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
 RUN bash /install/ubuntu_install_sphinx.sh
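With the image built, the only new runtime dependency is the `cuda-python` wheel, which ships the NVRTC bindings the new compile path relies on. A quick way to sanity-check an environment (a sketch, not part of this PR):

```python
# Verify that the cuda-python NVRTC bindings are importable and report
# the NVRTC version they expose.
from cuda.bindings import nvrtc

err, major, minor = nvrtc.nvrtcVersion()
assert err == nvrtc.nvrtcResult.NVRTC_SUCCESS
print(f"NVRTC {major}.{minor} is available")
```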
diff --git a/docker/install/ubuntu_install_cuda_python.sh b/docker/install/ubuntu_install_cuda_python.sh
new file mode 100644
index 000000000000..eb4efac5c050
--- /dev/null
+++ b/docker/install/ubuntu_install_cuda_python.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+set -u
+set -o pipefail
+
+pip3 install cuda-python
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index d062714938d6..edf3e8af4f00 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -16,14 +16,18 @@
 # under the License.
 # pylint: disable=invalid-name
 """Utility to invoke nvcc compiler in the system"""
+
 from __future__ import absolute_import as _abs
 
+import glob
 import os
+import platform
 import subprocess
 import warnings
 from typing import Tuple
 
 import tvm_ffi
+
 import tvm
 from tvm.target import Target
 
@@ -31,8 +35,10 @@
 from . import utils
 
 
-def compile_cuda(code, target_format=None, arch=None, options=None, path_target=None):
-    """Compile cuda code with NVCC from env.
+def compile_cuda(
+    code, target_format=None, arch=None, options=None, path_target=None, compiler="nvcc"
+):
+    """Compile cuda code with NVCC or NVRTC.
 
     Parameters
     ----------
@@ -40,7 +46,7 @@ def compile_cuda(
         The cuda code.
 
     target_format : str
-        The target format of nvcc compiler.
+        The target format of the compiler ("ptx", "cubin", or "fatbin").
 
     arch : str
         The cuda architecture.
@@ -51,14 +57,61 @@ def compile_cuda(
     path_target : str, optional
        Output file.
 
-    Return
-    ------
-    cubin : bytearray
-        The bytearray of the cubin
+    compiler : str, optional
+        Compiler backend: "nvcc" or "nvrtc". Defaults to "nvcc".
+        The default compile callback selects this via the
+        TVM_CUDA_COMPILE_MODE environment variable.
+
+    Returns
+    -------
+    res_binary : bytearray
+        The bytearray of the compiled binary (ptx/cubin/fatbin).
+
+    Notes
+    -----
+    - NVRTC is a "runtime" compilation library and can be faster for JIT compilation.
+    - NVRTC requires cuda-python: pip install cuda-python
+    """
+    # TODO: if NVSHMEM is needed for compilation, fall back to NVCC, since NVRTC
+    # support for it is not yet implemented.
+    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
+    if compiler == "nvcc" or use_nvshmem:
+        return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
+    elif compiler == "nvrtc":
+        return _compile_cuda_nvrtc(code, target_format, arch, options)
+    else:
+        raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: {compiler}")
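With the dispatcher above in place, both backends are reachable through one entry point. A usage sketch (the kernel source and the `sm_80` architecture are illustrative, not part of this PR):

```python
from tvm.contrib import nvcc

code = r"""
extern "C" __global__ void add_one(float* x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}
"""

# In-process JIT via NVRTC; returns a cubin for the requested architecture.
cubin = nvcc.compile_cuda(code, target_format="cubin", arch="sm_80", compiler="nvrtc")

# Subprocess compile via nvcc; fatbin is what the default TVM callback requests.
fatbin = nvcc.compile_cuda(code, target_format="fatbin", arch="sm_80", compiler="nvcc")
```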
""" # Check for NVSHMEM dependency nvshmem_include_path, nvshmem_lib_path = None, None - use_nvshmem = "#include " in code or "#include " in code if use_nvshmem: # NOTE: we cannot check whether nvshmem is used based on whether # the global function "runtime.nvshmem.cumodule_init" is defined. @@ -106,8 +159,9 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target= file_target = path_target if path_target else temp_target if use_nvshmem: - file_prefix = file_target.split(".")[0] + file_prefix = os.path.splitext(file_target)[0] file_target = f"{file_prefix}.o" # in the first stage, compile to object file + cmd = ["nvcc"] cmd += [f"--{target_format}", "-O3"] if kernels_output_dir is not None: @@ -151,14 +205,11 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target= msg += py_str(out) raise RuntimeError(msg) - # start second stage of compilation + # Second stage for NVSHMEM if use_nvshmem: cmd = ["nvlink"] cmd += [f"-arch=sm_{compute_version}"] - cmd += [ - "-L", - nvshmem_lib_path, - ] + cmd += ["-L", nvshmem_lib_path] cmd += ["-L", os.path.join(find_cuda_path(), "lib64")] cmd += ["-l", "nvshmem_device"] cmd += ["-l", "cudadevrt"] @@ -184,6 +235,187 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target= return data +def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None): + """Compile CUDA code using NVRTC (NVIDIA Runtime Compilation). + + Parameters + ---------- + code : str + The CUDA source code. + target_format : str, optional + Output format: "cubin" or "ptx". Default: "cubin" + arch : str, optional + Target architecture (e.g., "sm_80"). Auto-detected if None. + options : str or list of str, optional + Additional NVRTC options. + + Returns + ------- + bytearray + Compiled binary data. + """ + try: + from cuda.bindings import nvrtc # pylint: disable=import-outside-toplevel + except ImportError as e: + raise RuntimeError( + "Failed to compile CUDA with NVRTC because the `cuda-python` package " + "is not available.\n" + "Please install it with: pip install cuda-python\n" + "See: https://nvidia.github.io/cuda-python/" + ) from e + + # Default target format + if target_format is None: + target_format = "cubin" + + # Validate target_format (NVRTC doesn't support fatbin) + if target_format == "fatbin": + raise ValueError( + "NVRTC does not support fatbin generation yet. " + "Use target_format='cubin' or 'ptx' with NVRTC, " + "or set compiler='nvcc' for fatbin compilation." + ) + if target_format not in ["cubin", "ptx"]: + raise ValueError(f"target_format must be 'cubin' or 'ptx', got: {target_format}") + + # Validate options + if options is not None and not isinstance(options, (str, list)): + raise ValueError("options must be str or list of str") + + # Auto-detect architecture + if arch is None: + compute_version = get_target_compute_version(Target.current(allow_none=True)) + arch = f"sm_{''.join(compute_version.split('.'))}" + + # Strip host-only headers for NVRTC. NVRTC compiles device code and does not + # require the CUDA driver header or host C++ headers. + headers_to_strip = {"#include "} + code_filtered = "\n".join( + line for line in code.splitlines() if line.strip() not in headers_to_strip + ) + + # NVRTC compiles device code and does not include the host-side cuda.h. + # CUtensorMap is a host-side structure, to reference and use it in device code, + # we must forward-declare it for NVRTC. 
+ if "CUtensorMap" in code_filtered: + code_filtered = ( + "struct __align__(128) CUtensorMap {\n" + " unsigned long long opaque[16];\n" + "};\n\n" + code_filtered + ) + + # Create NVRTC program + # Use "tvm_kernels.cu" for consistency with nvcc path + result, prog = nvrtc.nvrtcCreateProgram( + str.encode(code_filtered), b"tvm_kernels.cu", 0, None, None + ) + if result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError(f"Failed to create NVRTC program: {nvrtc.nvrtcGetErrorString(result)}") + + # Prepare compilation options + cuda_path = find_cuda_path() + compile_opts = [ + f"--gpu-architecture={arch}".encode(), + b"-default-device", + ] + + # Add CUDA include paths. NVRTC needs explicit include paths for CUDA headers. + # Standard installations: cuda_path/include + # Conda/architecture-specific installations: cuda_path/targets//include + include_paths = [] + + # Check standard include directory + standard_include = os.path.join(cuda_path, "include") + if os.path.isdir(standard_include): + include_paths.append(standard_include) + + # Check architecture-specific include directory + arch_include = os.path.join( + cuda_path, + "targets", + f"{platform.machine()}-{platform.system().lower()}", + "include", + ) + if os.path.isdir(arch_include): + include_paths.append(arch_include) + + # Verify we can find essential CUDA headers + if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths): + raise RuntimeError( + f"Cannot find CUDA headers in {cuda_path}. " + f"Searched in: {include_paths}. " + "Please ensure CUDA is properly installed." + ) + + # Add all valid include paths + for include_path in include_paths: + compile_opts.append(f"-I{include_path}".encode()) + + compile_opts.extend( + [ + b"-U__CUDA_NO_HALF_OPERATORS__", + b"-U__CUDA_NO_HALF_CONVERSIONS__", + b"-U__CUDA_NO_BFLOAT16_OPERATORS__", + b"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + b"-U__CUDA_NO_BFLOAT162_OPERATORS__", + b"-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + b"--use_fast_math", + ] + ) + + # Add user-provided options + if options: + if isinstance(options, str): + compile_opts.append(options.encode()) + else: + compile_opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in options]) + + # Compile + (result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), compile_opts) + if result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + # Get compilation log + result_log, log_size = nvrtc.nvrtcGetProgramLogSize(prog) + if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS and log_size > 0: + log_buf = bytearray(log_size) + (result_log,) = nvrtc.nvrtcGetProgramLog(prog, log_buf) + if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS: + error_msg = f"NVRTC compilation failed:\n{log_buf.decode('utf-8')}" + else: + error_msg = f"NVRTC compilation failed (couldn't get log): {result}" + else: + error_msg = f"NVRTC compilation failed: {result}" + + nvrtc.nvrtcDestroyProgram(prog) + raise RuntimeError(error_msg) + + # Get compiled binary + if target_format == "cubin": + result, binary_size = nvrtc.nvrtcGetCUBINSize(prog) + if result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + nvrtc.nvrtcDestroyProgram(prog) + raise RuntimeError(f"Failed to get CUBIN size: {nvrtc.nvrtcGetErrorString(result)}") + binary_buf = bytearray(binary_size) + (result,) = nvrtc.nvrtcGetCUBIN(prog, binary_buf) + if result != nvrtc.nvrtcResult.NVRTC_SUCCESS: + nvrtc.nvrtcDestroyProgram(prog) + raise RuntimeError(f"Failed to get CUBIN: {nvrtc.nvrtcGetErrorString(result)}") + else: # ptx + result, binary_size = nvrtc.nvrtcGetPTXSize(prog) + if result != 
+            nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get PTX size: {nvrtc.nvrtcGetErrorString(result)}")
+        binary_buf = bytearray(binary_size)
+        (result,) = nvrtc.nvrtcGetPTX(prog, binary_buf)
+        if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+            nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}")
+
+    # Clean up
+    nvrtc.nvrtcDestroyProgram(prog)
+
+    return bytearray(binary_buf)
+
+
 def find_cuda_path():
     """Utility function to find cuda path
 
@@ -241,7 +473,7 @@ def get_cuda_version(cuda_path=None):
         (out, _) = proc.communicate()
         out = py_str(out)
         if proc.returncode == 0:
-            release_line = [l for l in out.split("\n") if "release" in l][0]
+            release_line = [line for line in out.split("\n") if "release" in line][0]
             release_fields = [s.strip() for s in release_line.split(",")]
             version_str = [f[1:] for f in release_fields if f.startswith("V")][0]
             return tuple(int(field) for field in version_str.split("."))
@@ -280,16 +512,37 @@ def find_nvshmem_paths() -> Tuple[str, str]:
             unique_candidates.append(path)
 
     for root in unique_candidates:
-        include_path = os.path.join(root, "include")
+        # Check both the standard include path and versioned subdirectories (e.g., nvshmem_12)
+        include_paths_to_check = [os.path.join(root, "include")]
+
+        # Add versioned subdirectories like include/nvshmem_*
+        versioned_includes = glob.glob(os.path.join(root, "include", "nvshmem_*"))
+        include_paths_to_check.extend(versioned_includes)
+
+        # Check standard and architecture-specific lib directories
         lib_paths_to_check = [
             os.path.join(root, "lib64"),
             os.path.join(root, "lib"),
         ]
-        if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
-            for lib_path in lib_paths_to_check:
-                if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")):
-                    return include_path, lib_path
+        # Add architecture-specific lib paths (e.g., lib/x86_64-linux-gnu)
+        machine = platform.machine()
+        system = platform.system().lower()
+        lib_paths_to_check.extend(
+            [
+                os.path.join(root, "lib", f"{machine}-{system}-gnu"),
+                os.path.join(root, "lib64", f"{machine}-{system}-gnu"),
+            ]
+        )
+
+        for include_path in include_paths_to_check:
+            if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
+                for lib_path in lib_paths_to_check:
+                    # Check for both static (.a) and shared (.so) libraries
+                    if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")) or os.path.isfile(
+                        os.path.join(lib_path, "libnvshmem.so")
+                    ):
+                        return include_path, lib_path
 
     error_message = [
         "Error: Could not find NVSHMEM installation.",
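The callback below is what `BuildCUDA` invokes from C++; it is also where `TVM_CUDA_COMPILE_MODE` takes effect. Opting in from user code could look like this (a sketch; the build call is illustrative and depends on your TVM workflow):

```python
import os

# Must be set before compilation is triggered; the callback reads it per invocation.
os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"

import tvm
# ... build as usual; tvm_callback_cuda_compile now emits a cubin via NVRTC, e.g.:
# lib = tvm.compile(mod, target="cuda")
```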
@@ -315,9 +568,39 @@ def find_nvshmem_paths() -> Tuple[str, str]:
 
 @tvm_ffi.register_global_func
 def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
-    """use nvcc to generate fatbin code for better optimization"""
-    ptx = compile_cuda(code, target_format="fatbin")
-    return ptx
+    """
+    Compile CUDA code using the configured backend (nvcc or nvrtc).
+
+    This callback is invoked by TVM's C++ backend during CUDA module compilation.
+    By default, it uses nvcc to generate a fatbin.
+
+    Environment Variables
+    ---------------------
+    TVM_CUDA_COMPILE_MODE : str
+        Compiler backend: "nvcc" (default) or "nvrtc"
+        - "nvcc": Use an nvcc subprocess; generates a fatbin
+        - "nvrtc": Use NVRTC via cuda-python for faster JIT; generates a cubin
+
+    Parameters
+    ----------
+    code : str
+        CUDA source code to compile
+    target : Target
+        TVM target architecture
+
+    Returns
+    -------
+    bytes
+        Compiled binary (fatbin for nvcc, cubin for nvrtc)
+    """
+    compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower()
+
+    if compiler == "nvrtc":
+        return compile_cuda(code, target_format="cubin", compiler="nvrtc")
+    if compiler == "nvcc":
+        return compile_cuda(code, target_format="fatbin", compiler="nvcc")
+
+    raise ValueError(f"Invalid TVM_CUDA_COMPILE_MODE: {compiler}. Expected 'nvcc' or 'nvrtc'.")
 
 
 @tvm_ffi.register_global_func("tvm_callback_libdevice_path")
diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 405e1e6cbf93..45a3d364c128 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -17,14 +17,15 @@
 """External kernel integration for TIR"""
 import json
 import logging
+import os
 import tempfile
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union
 
 from tvm import __version__ as tvm_version
 from tvm import tir
-from tvm.runtime import Module, load_module, const
 from tvm.contrib import nvcc
+from tvm.runtime import Module, const, load_module
 
 
 class BaseKernel:  # pylint: disable=too-few-public-methods
@@ -100,10 +101,15 @@ def __init__(self, source_code: str):
         self.source_code = source_code
 
     def compile_to_device_module(  # pylint: disable=arguments-differ
-        self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any]
+        self,
+        grid: List[List[Union[int, tir.PrimExpr]]],
+        *args: List[Any],
+        **kwargs: Dict[str, Any],
     ) -> Tuple[str, Module, List[Any]]:
         """Compile the kernel to a device module."""
-        from tvm.relax.frontend.nn import SourceModule  # pylint: disable=import-outside-toplevel
+        from tvm.relax.frontend.nn import (  # pylint: disable=import-outside-toplevel
+            SourceModule,
+        )
 
         kernel_name = kwargs["kernel_name"]
         assert len(grid) == 2, (
@@ -134,8 +140,13 @@ def compile_to_device_module(  # pylint: disable=arguments-differ
 
         with tempfile.TemporaryDirectory() as temp_dir:
             ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+            compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
             nvcc.compile_cuda(
-                source_code, target_format="ptx", options=compile_options, path_target=ptx_path
+                source_code,
+                target_format="ptx",
+                options=compile_options,
+                path_target=ptx_path,
+                compiler=compiler,
             )
             with open(ptx_path, "r") as f:
                 ptx = f.read()
@@ -171,7 +182,10 @@ def call_kernel(
         kwargs : Dict[str, Any]
             Additional keyword arguments to pass to the kernel or compilation.
         """
-        from ..ir import module_get_attr, module_set_attr  # pylint: disable=import-outside-toplevel
+        from ..ir import (  # pylint: disable=import-outside-toplevel
+            module_get_attr,
+            module_set_attr,
+        )
         from .ir import call_packed  # pylint: disable=import-outside-toplevel
 
         kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}"
diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc
index 3471902bc311..d682e2cae5b3 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <cuda.h>
 #include <nvshmem.h>
 #include <nvshmemx.h>
 #include <picojson.h>
@@ -117,7 +118,22 @@ void NVSHMEMXCumoduleInit(void* cuModule) {
   // NOTE: we do not check the return value of nvshmemx_cumodule_init.
   // The reason is that the input cuModule might not use any NVSHMEM functions,
   // in which case nvshmemx_cumodule_init will fail.
-  nvshmemx_cumodule_init(mod);
+
+  // A set of guards to check whether the module defines any NVSHMEM device symbol,
+  // to avoid the "gpgpu named symbol not found" error.
+  CUdeviceptr d_ptr;
+  size_t d_size;
+  const char* kNvshmemDeviceSymbols[] = {
+      "nvshmemi_device_state_d",      "nvshmem_i_device_state_d",
+      "nvshmemi_device_team_state_d", "nvshmemi_device_heap_base_d",
+      "nvshmemi_device_heap_size_d",  "nvshmemi_device_heap_d",
+  };
+  for (const char* sym : kNvshmemDeviceSymbols) {
+    if (cuModuleGetGlobal(&d_ptr, &d_size, mod, sym) == CUDA_SUCCESS) {
+      nvshmemx_cumodule_init(mod);
+      return;
+    }
+  }
 }
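The guard above probes the loaded module for NVSHMEM device symbols before calling `nvshmemx_cumodule_init`. For reference, the same probe can be expressed with `cuda-python`'s driver bindings (a sketch, not part of this PR; assumes an already-loaded `CUmodule` handle):

```python
from cuda.bindings import driver

def module_defines_symbol(module, name: str) -> bool:
    # cuModuleGetGlobal succeeds only if the module defines the named global,
    # mirroring the C++ guard that gates nvshmemx_cumodule_init.
    err, _dptr, _size = driver.cuModuleGetGlobal(module, name.encode())
    return err == driver.CUresult.CUDA_SUCCESS
```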
- << "In other than linux, it is necessary to set CUDA_PATH."; - return cuda_include_path; -} - -std::string NVRTCCompile(const std::string& code, bool include_path = false) { - std::vector compile_params; - std::vector param_cstrings{}; - nvrtcProgram prog; - std::string cc = "30"; - int major, minor; - cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0); - cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); - - if (e1 == cudaSuccess && e2 == cudaSuccess) { - cc = std::to_string(major) + std::to_string(minor); - } else { - LOG(WARNING) << "cannot detect compute capability from your device, " - << "fall back to compute_30."; - } - - compile_params.push_back("-arch=compute_" + cc); - - if (include_path) { - std::string include_option = "--include-path=" + FindCUDAIncludePath(); - - compile_params.push_back(include_option); - } - - for (const auto& string : compile_params) { - param_cstrings.push_back(string.c_str()); - } - NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); - nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); - - size_t log_size; - NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); - std::string log; - log.resize(log_size); - NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); - ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log; - size_t ptx_size; - NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); - - std::string ptx; - ptx.resize(ptx_size); - NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0])); - NVRTC_CALL(nvrtcDestroyProgram(&prog)); - - return ptx; -} +// Note: CUDA include path finding and NVRTC compilation are now handled +// in Python for better maintainability and to leverage cuda-python bindings. +// The C++ NVRTC code has been removed as part of the Python-first +// compilation strategy. ffi::Module BuildCUDA(IRModule mod, Target target) { bool output_ssa = false; @@ -157,20 +75,32 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { code = (*f)(code, target).cast(); } std::string fmt = "ptx"; - std::string ptx; + std::string compiled; + + // Always use Python compilation callback (nvcc or nvrtc) + // The C++ NVRTC fallback has been removed in favor of Python-first approach + auto f_compile = ffi::Function::GetGlobal("tvm_callback_cuda_compile"); + ICHECK(f_compile != nullptr) + << "tvm_callback_cuda_compile not found. " + << "Please ensure TVM Python runtime is properly initialized.\n" + << "The Python callback (tvm.contrib.nvcc.tvm_callback_cuda_compile) is required " + << "for CUDA compilation. The C++ NVRTC fallback has been removed.\n" + << "Make sure to import tvm.contrib.nvcc in your Python code."; + + // Enter target scope for compilation auto f_enter = ffi::Function::GetGlobal("target.TargetEnterScope"); (*f_enter)(target); - if (auto f = ffi::Function::GetGlobal("tvm_callback_cuda_compile")) { - ptx = (*f)(code, target).cast(); - // Dirty matching to check PTX vs cubin. - // TODO(tqchen) more reliable checks - if (ptx[0] != '/') fmt = "cubin"; - } else { - ptx = NVRTCCompile(code, cg.need_include_path()); - } + + // Compile CUDA code via Python callback + compiled = (*f_compile)(code, target).cast(); + // Dirty matching to check PTX vs cubin. 
+  // TODO(tqchen) more reliable checks
+  if (compiled[0] != '/') fmt = "cubin";
+
+  // Exit target scope
   auto f_exit = ffi::Function::GetGlobal("target.TargetExitScope");
   (*f_exit)(target);
-  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
+
+  return CUDAModuleCreate(compiled, fmt, ExtractFuncInfo(mod), code);
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index a9cfad9ab6f5..86201a2a05e3 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -310,10 +310,16 @@ std::string CodeGenCUDA::Finish() {
   decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
   decl_stream << "#endif\n";
 
+  // Emit type aliases, guarding int64_t/uint64_t for NVRTC compatibility
+  decl_stream << "\n#ifdef __CUDACC_RTC__\n";
+  decl_stream << "using int64_t = long long;\n";
+  decl_stream << "using uint64_t = unsigned long long;\n";
+  decl_stream << "#else\n";
   decl_stream << "#include <cstdint>\n";
+  decl_stream << "#endif\n";
   decl_stream << "using uint = unsigned int;\n";
   decl_stream << "using uchar = unsigned char;\n";
-  decl_stream << "using ushort = unsigned short;\n";
+  decl_stream << "using ushort = unsigned short;\n\n";
 
   return CodeGenC::Finish();
 }
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 3f1fcbc2dcd7..682845e9e7e3 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -391,7 +391,9 @@ void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16
                                     bool enable_fp8, bool enable_fp4) {
   if (enable_fp16 || enable_bf16) {
     stream << R"(
-#include <type_traits>
+template <typename T1, typename T2> struct is_same { static constexpr bool value = false; };
+template <typename T> struct is_same<T, T> { static constexpr bool value = true; };
+
 template <typename T, typename TVec2>
 struct __align__(8) half4_bfloat164 {
   T x, y, z, w;
@@ -401,7 +403,7 @@ struct __align__(8) half4_bfloat164 {
   if (enable_fp8) {
     stream << R"(
   __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
-    if constexpr (std::is_same_v<T, half>) {
+    if constexpr (is_same<T, half>::value) {
       __nv_fp8x2_e4m3 lo_part, hi_part;
       lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
       hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
@@ -481,7 +483,7 @@ struct __align__(8) half4_bfloat164 {
   if (enable_fp4) {
     stream << R"(
   __host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
-    if constexpr (std::is_same_v<T, half>) {
+    if constexpr (is_same<T, half>::value) {
       __nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
       __nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
       TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 1b31e64414b1..177541da0820 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -20,6 +20,7 @@
 import pytest
 
 import tvm
+import tvm.contrib.nvcc
 import tvm.testing
 from tvm import te, topi
 from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8
@@ -27,6 +28,31 @@
 from tvm.script import tir as T
 
 
+@pytest.fixture(autouse=True, params=["nvcc", "nvrtc"])
+def setup_cuda_compile_mode(request):
+    mode = request.param
+    if mode == "nvrtc":
+        try:
+            from cuda.bindings import nvrtc
+        except ImportError:
+            pytest.skip("cuda-python not available, skipping nvrtc tests")
+
+    orig_func = tvm.contrib.nvcc.tvm_callback_cuda_compile
+
+    def compile_mode_wrapper(code, target):
+        if mode == "nvcc":
+            return tvm.contrib.nvcc.compile_cuda(code, target_format="fatbin", compiler="nvcc")
+        elif mode == "nvrtc":
+            return tvm.contrib.nvcc.compile_cuda(code, target_format="cubin", compiler="nvrtc")
+        else:
+            raise ValueError(f"Unknown mode: {mode}")
+
+    tvm.register_global_func("tvm_callback_cuda_compile", compile_mode_wrapper, override=True)
+    # Yield so the test body runs under this compile mode; since the fixture is
+    # parametrized, each test runs once per backend. Restore the original callback after.
+    yield
+    tvm.register_global_func("tvm_callback_cuda_compile", orig_func, override=True)
+
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_vectorize_add():
@@ -201,13 +227,13 @@ def check_cuda(n, value, lanes):
         fun(a)
         np.testing.assert_equal(a.numpy(), np_a)
 
-    check_cuda(64, np.int8(0xAB), 4)
+    check_cuda(64, np.uint8(0xAB).view(np.int8), 4)
     check_cuda(64, 0, 4)
     check_cuda(64, -3, 4)
-    check_cuda(64, np.int8(0xAB), 3)
+    check_cuda(64, np.uint8(0xAB).view(np.int8), 3)
     check_cuda(64, 0, 3)
     check_cuda(64, -3, 3)
-    check_cuda(64, np.int8(0xAB), 2)
+    check_cuda(64, np.uint8(0xAB).view(np.int8), 2)
     check_cuda(64, 0, 2)
     check_cuda(64, -3, 2)
 
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index d9976e05e50b..029eb8fe824a 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -16,13 +16,15 @@
 # under the License.
 """Basic tests for Disco NVSHMEM support"""
 # pylint: disable=missing-docstring
-import tempfile
-
 import numpy as np
 import pytest
+
+import shutil
 import subprocess
-import threading
 import sys
+import tempfile
+import threading
+import multiprocessing
 from multiprocessing import Process
 from typing import Any, Callable, List
 
@@ -160,7 +162,8 @@ def main(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
                     T.writes(B[v1, v0])
                     B[v1, v0] = A[v0, v1]
 
-    with tempfile.TemporaryDirectory() as tmpdir:
+    tmpdir = tempfile.mkdtemp()
+    try:
         path = tmpdir + "/test.so"
         A_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
         B_np = np.zeros((16, 8), dtype="float32")
@@ -180,9 +183,12 @@ def main(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
         # finish the execution
         sess._sync_all()
 
-    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
-    finalize_dfunc()
-    sess.sync_worker_0()
+        finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+        finalize_dfunc()
+        sess.sync_worker_0()
+    finally:
+        sess.shutdown()
+        shutil.rmtree(tmpdir, ignore_errors=True)
 
 
 if __name__ == "__main__":
@@ -190,14 +196,24 @@
     # or `nvshmem_init_thread` in the same program results in undefined behavior.
    # So we always create a new process to run the test. Then no repeated nvshmem
     # init happens in the same process, since worker0 may share the same process.
+
+    # Use the 'spawn' start method to avoid inheriting CUDA state from the parent process.
+    # 'fork' (the default on Linux) can cause issues with CUDA contexts in child processes.
+    multiprocessing.set_start_method("spawn", force=True)
+
     for session_kind in [create_socket_session, di.ProcessSession]:
         for num_workers in [2, 4]:
             for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
                 p = Process(target=test_func, args=[session_kind, num_workers])
                 p.start()
                 p.join()
+                # Ensure the process finished successfully
+                assert (
+                    p.exitcode == 0
+                ), f"Test {test_func.__name__} failed with exit code {p.exitcode}"
 
     # testing compilation flow
     p = Process(target=test_nvshmem_compile)
     p.start()
     p.join()
+    assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code {p.exitcode}"
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 0855afcfd64a..aa7e2b357564 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -256,10 +256,17 @@ def test_inject_async_copy_barrier():
 #else
 #define TVM_ENABLE_L2_PREFETCH 0
 #endif
+
+#ifdef __CUDACC_RTC__
+using int64_t = long long;
+using uint64_t = unsigned long long;
+#else
 #include <cstdint>
+#endif
 using uint = unsigned int;
 using uchar = unsigned char;
 using ushort = unsigned short;
+
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C);
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];