diff --git a/cuda_core/build_hooks.py b/cuda_core/build_hooks.py index e38f5676d..9cc2c5948 100644 --- a/cuda_core/build_hooks.py +++ b/cuda_core/build_hooks.py @@ -7,11 +7,13 @@ # - https://setuptools.pypa.io/en/latest/build_meta.html#dynamic-build-dependencies-and-other-build-meta-tweaks # Specifically, there are 5 APIs required to create a proper build backend, see below. +import ctypes import functools import glob import os +import pathlib import re -import subprocess +import sys from Cython.Build import cythonize from setuptools import Extension @@ -23,6 +25,88 @@ get_requires_for_build_sdist = _build_meta.get_requires_for_build_sdist +@functools.cache +def _get_cuda_paths(): + CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None)) + if CUDA_PATH is None: + return None + CUDA_PATH = CUDA_PATH.split(os.pathsep) + print("CUDA paths:", CUDA_PATH, flush=True) + return CUDA_PATH + + +@functools.cache +def _get_cuda_version_from_cuda_h(cuda_home=None): + """ + Given CUDA_HOME, try to extract the CUDA_VERSION macro from include/cuda.h. + + Example line in cuda.h: + #define CUDA_VERSION 13000 + + Returns the integer (e.g. 13000) or None if not found / on error. + """ + if cuda_home is None: + cuda_home = _get_cuda_paths() + if cuda_home is None: + return None + else: + cuda_home = cuda_home[0] + + cuda_h = pathlib.Path(cuda_home) / "include" / "cuda.h" + if not cuda_h.is_file(): + return None + + try: + text = cuda_h.read_text(encoding="utf-8", errors="ignore") + except OSError: + # Permissions issue, unreadable file, etc. + return None + + m = re.search(r"^\s*#define\s+CUDA_VERSION\s+(\d+)", text, re.MULTILINE) + if not m: + return None + print(f"CUDA_VERSION from {cuda_h}:", m.group(1), flush=True) + return int(m.group(1)) + + +def _get_cuda_driver_version(): + """ + Try to load ``libcuda.so`` or ``nvcuda.dll`` via standard dynamic library lookup + and call ``cuDriverGetVersion``. + + Returns the integer (e.g. 13000) or None if not found / on error. + """ + CUDA_SUCCESS = 0 + + if sys.platform == "win32": + try: + # WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API. + lib = ctypes.WinDLL("nvcuda.dll") + except OSError: + return None + else: + cdll_mode = os.RTLD_NOW | os.RTLD_GLOBAL + try: + # Use system search paths only; do not provide an absolute path. + # Make symbols globally available to any dependent libraries. + lib = ctypes.CDLL("libcuda.so.1", mode=cdll_mode) + except OSError: + return None + + # int cuDriverGetVersion(int* driverVersion); + cuDriverGetVersion = lib.cuDriverGetVersion + cuDriverGetVersion.restype = ctypes.c_int # CUresult + cuDriverGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] + + out = ctypes.c_int(0) + rc = cuDriverGetVersion(ctypes.byref(out)) + if rc != CUDA_SUCCESS: + return None + + print("CUDA_VERSION from driver:", int(out.value), flush=True) + return int(out.value) + + @functools.cache def _get_proper_cuda_bindings_major_version() -> str: # for local development (with/without build isolation) @@ -38,15 +122,14 @@ def _get_proper_cuda_bindings_major_version() -> str: if cuda_major is not None: return cuda_major + cuda_version = _get_cuda_version_from_cuda_h() + if cuda_version: + return str(cuda_version // 1000) + # also for local development - try: - out = subprocess.run("nvidia-smi", env=os.environ, capture_output=True, check=True) # noqa: S603, S607 - m = re.search(r"CUDA Version:\s*([\d\.]+)", out.stdout.decode()) - if m: - return m.group(1).split(".")[0] - except (FileNotFoundError, subprocess.CalledProcessError): - # the build machine has no driver installed - pass + cuda_version = _get_cuda_driver_version() + if cuda_version: + return str(cuda_version // 1000) # default fallback return "13" @@ -75,20 +158,11 @@ def strip_prefix_suffix(filename): module_names = (strip_prefix_suffix(f) for f in ext_files) - @functools.cache - def get_cuda_paths(): - CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None)) - if not CUDA_PATH: - raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set") - CUDA_PATH = CUDA_PATH.split(os.pathsep) - print("CUDA paths:", CUDA_PATH) - return CUDA_PATH - ext_modules = tuple( Extension( f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}", sources=[f"cuda/core/experimental/{mod}.pyx"], - include_dirs=list(os.path.join(root, "include") for root in get_cuda_paths()), + include_dirs=list(os.path.join(root, "include") for root in _get_cuda_paths()), language="c++", ) for mod in module_names