1111from cuda .core .experimental ._utils import ComputeCapability , CUDAError , driver , handle_return , precondition , runtime
1212
1313_tls = threading .local ()
14- _tls_lock = threading .Lock ()
14+ _lock = threading .Lock ()
15+ _is_cuInit = False
1516
1617
1718class DeviceProperties :
@@ -938,6 +939,12 @@ class Device:
938939 __slots__ = ("_id" , "_mr" , "_has_inited" , "_properties" )
939940
940941 def __new__ (cls , device_id = None ):
942+ global _is_cuInit
943+ if _is_cuInit is False :
944+ with _lock :
945+ handle_return (driver .cuInit (0 ))
946+ _is_cuInit = True
947+
941948 # important: creating a Device instance does not initialize the GPU!
942949 if device_id is None :
943950 device_id = handle_return (runtime .cudaGetDevice ())
@@ -948,27 +955,26 @@ def __new__(cls, device_id=None):
948955 raise ValueError (f"device_id must be within [0, { total } ), got { device_id } " )
949956
950957 # ensure Device is singleton
951- with _tls_lock :
952- if not hasattr (_tls , "devices" ):
953- total = handle_return (runtime .cudaGetDeviceCount ())
954- _tls .devices = []
955- for dev_id in range (total ):
956- dev = super ().__new__ (cls )
957- dev ._id = dev_id
958- # If the device is in TCC mode, or does not support memory pools for some other reason,
959- # use the SynchronousMemoryResource which does not use memory pools.
960- if (
961- handle_return (
962- runtime .cudaDeviceGetAttribute (runtime .cudaDeviceAttr .cudaDevAttrMemoryPoolsSupported , 0 )
963- )
964- ) == 1 :
965- dev ._mr = _DefaultAsyncMempool (dev_id )
966- else :
967- dev ._mr = _SynchronousMemoryResource (dev_id )
968-
969- dev ._has_inited = False
970- dev ._properties = None
971- _tls .devices .append (dev )
958+ if not hasattr (_tls , "devices" ):
959+ total = handle_return (runtime .cudaGetDeviceCount ())
960+ _tls .devices = []
961+ for dev_id in range (total ):
962+ dev = super ().__new__ (cls )
963+ dev ._id = dev_id
964+ # If the device is in TCC mode, or does not support memory pools for some other reason,
965+ # use the SynchronousMemoryResource which does not use memory pools.
966+ if (
967+ handle_return (
968+ runtime .cudaDeviceGetAttribute (runtime .cudaDeviceAttr .cudaDevAttrMemoryPoolsSupported , 0 )
969+ )
970+ ) == 1 :
971+ dev ._mr = _DefaultAsyncMempool (dev_id )
972+ else :
973+ dev ._mr = _SynchronousMemoryResource (dev_id )
974+
975+ dev ._has_inited = False
976+ dev ._properties = None
977+ _tls .devices .append (dev )
972978
973979 return _tls .devices [device_id ]
974980
0 commit comments