@@ -964,19 +964,19 @@ def __new__(cls, device_id=None):
964964 ctx = handle_return (driver .cuCtxGetCurrent ())
965965 assert int (ctx ) == 0
966966 device_id = 0 # cudart behavior
967- assert isinstance (device_id , int ), f"{ device_id = } "
968967 else :
969968 total = handle_return (driver .cuDeviceGetCount ())
970969 if not isinstance (device_id , int ) or not (0 <= device_id < total ):
971970 raise ValueError (f"device_id must be within [0, { total } ), got { device_id } " )
972971
973972 # ensure Device is singleton
974- if not hasattr (_tls , "devices" ):
973+ try :
974+ devices = _tls .devices
975+ except AttributeError :
975976 total = handle_return (driver .cuDeviceGetCount ())
976- _tls .devices = []
977+ devices = _tls .devices = []
977978 for dev_id in range (total ):
978979 dev = super ().__new__ (cls )
979-
980980 dev ._id = dev_id
981981 # If the device is in TCC mode, or does not support memory pools for some other reason,
982982 # use the SynchronousMemoryResource which does not use memory pools.
@@ -990,12 +990,12 @@ def __new__(cls, device_id=None):
990990 dev ._mr = _DefaultAsyncMempool (dev_id )
991991 else :
992992 dev ._mr = _SynchronousMemoryResource (dev_id )
993+
993994 dev ._has_inited = False
994995 dev ._properties = None
996+ devices .append (dev )
995997
996- _tls .devices .append (dev )
997-
998- return _tls .devices [device_id ]
998+ return devices [device_id ]
999999
10001000 def _check_context_initialized (self , * args , ** kwargs ):
10011001 if not self ._has_inited :
0 commit comments