Skip to content

Commit 3a2260a

Browse files
committed
reduction
1 parent 897b07b commit 3a2260a

File tree

1 file changed

+30
-55
lines changed

1 file changed

+30
-55
lines changed

cuda_core/build_hooks.py

Lines changed: 30 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
@functools.cache
2929
def _get_cuda_paths():
3030
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
31-
if not CUDA_PATH:
32-
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
31+
if CUDA_PATH is None:
32+
return None
3333
CUDA_PATH = CUDA_PATH.split(os.pathsep)
3434
print("CUDA paths:", CUDA_PATH, flush=True)
3535
return CUDA_PATH
@@ -46,7 +46,11 @@ def _get_cuda_version_from_cuda_h(cuda_home=None):
4646
Returns the integer (e.g. 13000) or None if not found / on error.
4747
"""
4848
if cuda_home is None:
49-
cuda_home = _get_cuda_paths()[0]
49+
cuda_home = _get_cuda_paths()
50+
if cuda_home is None:
51+
return None
52+
else:
53+
cuda_home = cuda_home[0]
5054

5155
cuda_h = pathlib.Path(cuda_home) / "include" / "cuda.h"
5256
if not cuda_h.is_file():
@@ -62,58 +66,32 @@ def _get_cuda_version_from_cuda_h(cuda_home=None):
6266
if not m:
6367
return None
6468
print(f"CUDA_VERSION from {cuda_h}:", m.group(1), flush=True)
65-
return m.group(1)
66-
69+
return int(m.group(1))
6770

68-
def _get_cuda_driver_version_linux():
69-
"""
70-
Linux-only. Try to load `libcuda.so` via standard dynamic library lookup
71-
and call `CUresult cuDriverGetVersion(int* driverVersion)`.
7271

73-
Returns:
74-
int : driver version (e.g., 12040 for 12.4), if successful.
75-
None : on any failure (load error, missing symbol, non-success CUresult).
72+
def _get_cuda_driver_version():
7673
"""
77-
CUDA_SUCCESS = 0
78-
79-
libcuda_so = "libcuda.so.1"
80-
cdll_mode = os.RTLD_NOW | os.RTLD_GLOBAL
81-
try:
82-
# Use system search paths only; do not provide an absolute path.
83-
# Make symbols globally available to any dependent libraries.
84-
lib = ctypes.CDLL(libcuda_so, mode=cdll_mode)
85-
except OSError:
86-
return None
74+
Try to load ``libcuda.so`` or ``nvcuda.dll`` via standard dynamic library lookup
75+
and call ``cuDriverGetVersion``.
8776
88-
# int cuDriverGetVersion(int* driverVersion);
89-
lib.cuDriverGetVersion.restype = ctypes.c_int # CUresult
90-
lib.cuDriverGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
91-
92-
out = ctypes.c_int(0)
93-
rc = lib.cuDriverGetVersion(ctypes.byref(out))
94-
if rc != CUDA_SUCCESS:
95-
return None
96-
97-
print(f"CUDA_VERSION from {libcuda_so}:", int(out.value), flush=True)
98-
return int(out.value)
99-
100-
101-
def _get_cuda_driver_version_windows():
102-
"""
103-
Windows-only. Load `nvcuda.dll` via normal system search and call
104-
CUresult cuDriverGetVersion(int* driverVersion).
105-
106-
Returns:
107-
int : driver version (e.g., 12040 for 12.4), if successful.
108-
None : on any failure (load error, missing symbol, non-success CUresult).
77+
Returns the integer (e.g. 13000) or None if not found / on error.
10978
"""
11079
CUDA_SUCCESS = 0
11180

112-
try:
113-
# WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API.
114-
lib = ctypes.WinDLL("nvcuda.dll")
115-
except OSError:
116-
return None
81+
if sys.platform == "win32":
82+
try:
83+
# WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API.
84+
lib = ctypes.WinDLL("nvcuda.dll")
85+
except OSError:
86+
return None
87+
else:
88+
cdll_mode = os.RTLD_NOW | os.RTLD_GLOBAL
89+
try:
90+
# Use system search paths only; do not provide an absolute path.
91+
# Make symbols globally available to any dependent libraries.
92+
lib = ctypes.CDLL("libcuda.so.1", mode=cdll_mode)
93+
except OSError:
94+
return None
11795

11896
# int cuDriverGetVersion(int* driverVersion);
11997
cuDriverGetVersion = lib.cuDriverGetVersion
@@ -125,7 +103,7 @@ def _get_cuda_driver_version_windows():
125103
if rc != CUDA_SUCCESS:
126104
return None
127105

128-
print("CUDA_VERSION from nvcuda.dll:", int(out.value), flush=True)
106+
print("CUDA_VERSION from driver:", int(out.value), flush=True)
129107
return int(out.value)
130108

131109

@@ -145,14 +123,11 @@ def _get_proper_cuda_bindings_major_version() -> str:
145123
return cuda_major
146124

147125
cuda_version = _get_cuda_version_from_cuda_h()
148-
if cuda_version and len(cuda_version) > 3:
149-
return cuda_version[:-3]
126+
if cuda_version:
127+
return str(cuda_version // 1000)
150128

151129
# also for local development
152-
if sys.platform == "win32":
153-
cuda_version = _get_cuda_driver_version_windows()
154-
else:
155-
cuda_version = _get_cuda_driver_version_linux()
130+
cuda_version = _get_cuda_driver_version()
156131
if cuda_version:
157132
return str(cuda_version // 1000)
158133

0 commit comments

Comments
 (0)