2828@functools .cache
2929def _get_cuda_paths ():
3030 CUDA_PATH = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , None ))
31- if not CUDA_PATH :
32- raise RuntimeError ( "Environment variable CUDA_PATH or CUDA_HOME is not set" )
31+ if CUDA_PATH is None :
32+ return None
3333 CUDA_PATH = CUDA_PATH .split (os .pathsep )
3434 print ("CUDA paths:" , CUDA_PATH , flush = True )
3535 return CUDA_PATH
@@ -46,7 +46,11 @@ def _get_cuda_version_from_cuda_h(cuda_home=None):
4646 Returns the integer (e.g. 13000) or None if not found / on error.
4747 """
4848 if cuda_home is None :
49- cuda_home = _get_cuda_paths ()[0 ]
49+ cuda_home = _get_cuda_paths ()
50+ if cuda_home is None :
51+ return None
52+ else :
53+ cuda_home = cuda_home [0 ]
5054
5155 cuda_h = pathlib .Path (cuda_home ) / "include" / "cuda.h"
5256 if not cuda_h .is_file ():
@@ -62,58 +66,32 @@ def _get_cuda_version_from_cuda_h(cuda_home=None):
6266 if not m :
6367 return None
6468 print (f"CUDA_VERSION from { cuda_h } :" , m .group (1 ), flush = True )
65- return m .group (1 )
66-
69+ return int (m .group (1 ))
6770
68- def _get_cuda_driver_version_linux ():
69- """
70- Linux-only. Try to load `libcuda.so` via standard dynamic library lookup
71- and call `CUresult cuDriverGetVersion(int* driverVersion)`.
7271
73- Returns:
74- int : driver version (e.g., 12040 for 12.4), if successful.
75- None : on any failure (load error, missing symbol, non-success CUresult).
72+ def _get_cuda_driver_version ():
7673 """
77- CUDA_SUCCESS = 0
78-
79- libcuda_so = "libcuda.so.1"
80- cdll_mode = os .RTLD_NOW | os .RTLD_GLOBAL
81- try :
82- # Use system search paths only; do not provide an absolute path.
83- # Make symbols globally available to any dependent libraries.
84- lib = ctypes .CDLL (libcuda_so , mode = cdll_mode )
85- except OSError :
86- return None
74+ Try to load ``libcuda.so`` or ``nvcuda.dll`` via standard dynamic library lookup
75+ and call ``cuDriverGetVersion``.
8776
88- # int cuDriverGetVersion(int* driverVersion);
89- lib .cuDriverGetVersion .restype = ctypes .c_int # CUresult
90- lib .cuDriverGetVersion .argtypes = [ctypes .POINTER (ctypes .c_int )]
91-
92- out = ctypes .c_int (0 )
93- rc = lib .cuDriverGetVersion (ctypes .byref (out ))
94- if rc != CUDA_SUCCESS :
95- return None
96-
97- print (f"CUDA_VERSION from { libcuda_so } :" , int (out .value ), flush = True )
98- return int (out .value )
99-
100-
101- def _get_cuda_driver_version_windows ():
102- """
103- Windows-only. Load `nvcuda.dll` via normal system search and call
104- CUresult cuDriverGetVersion(int* driverVersion).
105-
106- Returns:
107- int : driver version (e.g., 12040 for 12.4), if successful.
108- None : on any failure (load error, missing symbol, non-success CUresult).
77+ Returns the integer (e.g. 13000) or None if not found / on error.
10978 """
11079 CUDA_SUCCESS = 0
11180
112- try :
113- # WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API.
114- lib = ctypes .WinDLL ("nvcuda.dll" )
115- except OSError :
116- return None
81+ if sys .platform == "win32" :
82+ try :
83+ # WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API.
84+ lib = ctypes .WinDLL ("nvcuda.dll" )
85+ except OSError :
86+ return None
87+ else :
88+ cdll_mode = os .RTLD_NOW | os .RTLD_GLOBAL
89+ try :
90+ # Use system search paths only; do not provide an absolute path.
91+ # Make symbols globally available to any dependent libraries.
92+ lib = ctypes .CDLL ("libcuda.so.1" , mode = cdll_mode )
93+ except OSError :
94+ return None
11795
11896 # int cuDriverGetVersion(int* driverVersion);
11997 cuDriverGetVersion = lib .cuDriverGetVersion
@@ -125,7 +103,7 @@ def _get_cuda_driver_version_windows():
125103 if rc != CUDA_SUCCESS :
126104 return None
127105
128- print ("CUDA_VERSION from nvcuda.dll :" , int (out .value ), flush = True )
106+ print ("CUDA_VERSION from driver :" , int (out .value ), flush = True )
129107 return int (out .value )
130108
131109
@@ -145,14 +123,11 @@ def _get_proper_cuda_bindings_major_version() -> str:
145123 return cuda_major
146124
147125 cuda_version = _get_cuda_version_from_cuda_h ()
148- if cuda_version and len ( cuda_version ) > 3 :
149- return cuda_version [: - 3 ]
126+ if cuda_version :
127+ return str ( cuda_version // 1000 )
150128
151129 # also for local development
152- if sys .platform == "win32" :
153- cuda_version = _get_cuda_driver_version_windows ()
154- else :
155- cuda_version = _get_cuda_driver_version_linux ()
130+ cuda_version = _get_cuda_driver_version ()
156131 if cuda_version :
157132 return str (cuda_version // 1000 )
158133
0 commit comments