Skip to content
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 118 additions & 19 deletions cuda_core/build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
#
# TODO: also implement PEP-660 API hooks

import ctypes
import functools
import glob
import os
import pathlib
import re
import subprocess
import sys

from Cython.Build import cythonize
from setuptools import Extension
Expand All @@ -24,6 +26,110 @@
get_requires_for_build_sdist = _build_meta.get_requires_for_build_sdist


@functools.cache
def _get_cuda_paths() -> list[str]:
    """Return the CUDA installation root(s) named by the environment.

    ``CUDA_PATH`` is consulted first and ``CUDA_HOME`` second. The value may
    hold several roots separated by ``os.pathsep``. The result is cached, so
    the environment is only inspected (and the paths only printed) once per
    process.

    Returns:
        A non-empty list of path strings, in the order they were given.

    Raises:
        RuntimeError: if neither variable yields at least one non-empty path.
    """
    # Use ``or`` rather than a ``dict.get`` default so that an empty
    # CUDA_PATH does not hide a valid CUDA_HOME.
    raw_value = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")
    # Drop empty entries produced by stray separators (e.g. a trailing ":"),
    # which would otherwise turn into relative search roots.
    cuda_paths = [entry for entry in (raw_value or "").split(os.pathsep) if entry]
    if not cuda_paths:
        raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
    print("CUDA paths:", cuda_paths, flush=True)
    return cuda_paths


@functools.cache
def _get_cuda_version_from_cuda_h(cuda_home=None):
    """
    Extract the value of the CUDA_VERSION macro from ``<cuda_home>/include/cuda.h``.

    Example line in cuda.h:
        #define CUDA_VERSION 13000

    When ``cuda_home`` is None, the first entry of ``_get_cuda_paths()`` is used.

    Returns the macro's digits as a string (e.g. "13000"), or None if the
    header does not exist, cannot be read, or does not define the macro.
    Results are cached per ``cuda_home`` value.
    """
    root = _get_cuda_paths()[0] if cuda_home is None else cuda_home
    header = pathlib.Path(root) / "include" / "cuda.h"
    if not header.is_file():
        return None

    try:
        contents = header.read_text(encoding="utf-8", errors="ignore")
    except OSError:
        # Permissions issue, unreadable file, etc.
        return None

    match = re.search(r"^\s*#define\s+CUDA_VERSION\s+(\d+)", contents, re.MULTILINE)
    if match is None:
        return None

    version_digits = match.group(1)
    print(f"CUDA_VERSION from {header}:", version_digits, flush=True)
    return version_digits


def _get_cuda_driver_version_linux():
    """
    Linux-only. Try to load `libcuda.so.1` via standard dynamic library lookup
    and call `CUresult cuDriverGetVersion(int* driverVersion)`.

    Returns:
        int  : driver version (e.g., 12040 for 12.4), if successful.
        None : on any failure (load error, missing symbol, non-success CUresult).
    """
    CUDA_SUCCESS = 0

    libcuda_so = "libcuda.so.1"
    try:
        # Use system search paths only; do not provide an absolute path.
        # Make symbols globally available to any dependent libraries.
        cdll_mode = os.RTLD_NOW | os.RTLD_GLOBAL
        lib = ctypes.CDLL(libcuda_so, mode=cdll_mode)
        # ctypes resolves the symbol on attribute access and raises
        # AttributeError when the library does not export it; treat that the
        # same as a load failure so the documented "None" contract holds.
        cu_driver_get_version = lib.cuDriverGetVersion
    except (OSError, AttributeError):
        return None

    # int cuDriverGetVersion(int* driverVersion);
    cu_driver_get_version.restype = ctypes.c_int  # CUresult
    cu_driver_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]

    out = ctypes.c_int(0)
    rc = cu_driver_get_version(ctypes.byref(out))
    if rc != CUDA_SUCCESS:
        return None

    driver_version = int(out.value)
    print(f"CUDA_VERSION from {libcuda_so}:", driver_version, flush=True)
    return driver_version


def _get_cuda_driver_version_windows():
    """
    Windows-only. Load `nvcuda.dll` via normal system search and call
    CUresult cuDriverGetVersion(int* driverVersion).

    Returns:
        int  : driver version (e.g., 12040 for 12.4), if successful.
        None : on any failure (load error, missing symbol, non-success CUresult,
               or `ctypes.WinDLL` being unavailable on this platform).
    """
    CUDA_SUCCESS = 0

    try:
        # WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API.
        lib = ctypes.WinDLL("nvcuda.dll")
        # Attribute access performs the symbol lookup and raises
        # AttributeError if the DLL does not export it. The same exception
        # is raised when ctypes has no WinDLL (non-Windows platforms), so
        # both cases fall through to the documented "None" result.
        cuDriverGetVersion = lib.cuDriverGetVersion
    except (OSError, AttributeError):
        return None

    # int cuDriverGetVersion(int* driverVersion);
    cuDriverGetVersion.restype = ctypes.c_int  # CUresult
    cuDriverGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

    out = ctypes.c_int(0)
    rc = cuDriverGetVersion(ctypes.byref(out))
    if rc != CUDA_SUCCESS:
        return None

    driver_version = int(out.value)
    print("CUDA_VERSION from nvcuda.dll:", driver_version, flush=True)
    return driver_version


@functools.cache
def _get_proper_cuda_bindings_major_version() -> str:
# for local development (with/without build isolation)
Expand All @@ -39,15 +145,17 @@ def _get_proper_cuda_bindings_major_version() -> str:
if cuda_major is not None:
return cuda_major

cuda_version = _get_cuda_version_from_cuda_h()
if cuda_version and len(cuda_version) > 3:
return cuda_version[:-3]

# also for local development
try:
out = subprocess.run("nvidia-smi", env=os.environ, capture_output=True, check=True) # noqa: S603, S607
m = re.search(r"CUDA Version:\s*([\d\.]+)", out.stdout.decode())
if m:
return m.group(1).split(".")[0]
except (FileNotFoundError, subprocess.CalledProcessError):
# the build machine has no driver installed
pass
if sys.platform == "win32":
cuda_version = _get_cuda_driver_version_windows()
else:
cuda_version = _get_cuda_driver_version_linux()
if cuda_version:
return str(cuda_version // 1000)

# default fallback
return "13"
Expand All @@ -73,20 +181,11 @@ def strip_prefix_suffix(filename):

module_names = (strip_prefix_suffix(f) for f in ext_files)

@functools.cache
def get_cuda_paths():
    """Return the CUDA installation root(s) named by the environment.

    ``CUDA_PATH`` is consulted first and ``CUDA_HOME`` second. The value may
    hold several roots separated by ``os.pathsep``. The result is cached, so
    the environment is only inspected (and the paths only printed) once.

    Returns:
        A non-empty list of path strings, in the order they were given.

    Raises:
        RuntimeError: if neither variable yields at least one non-empty path.
    """
    # Use ``or`` rather than a ``dict.get`` default so that an empty
    # CUDA_PATH does not hide a valid CUDA_HOME.
    raw_value = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")
    # Drop empty entries produced by stray separators (e.g. a trailing ":"),
    # which would otherwise turn into relative include directories.
    cuda_paths = [entry for entry in (raw_value or "").split(os.pathsep) if entry]
    if not cuda_paths:
        raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
    print("CUDA paths:", cuda_paths)
    return cuda_paths

ext_modules = tuple(
Extension(
f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}",
sources=[f"cuda/core/experimental/{mod}.pyx"],
include_dirs=list(os.path.join(root, "include") for root in get_cuda_paths()),
include_dirs=list(os.path.join(root, "include") for root in _get_cuda_paths()),
language="c++",
)
for mod in module_names
Expand Down