|
11 | 11 | from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime |
12 | 12 |
|
# Per-thread cache of Device singleton instances (populated lazily in
# Device.__new__; thread-local, so no lock is needed around it).
_tls = threading.local()

# Serializes the one-time driver initialization performed in Device.__new__.
_lock = threading.Lock()

# Flipped to True after cuInit(0) has been issued once for the process.
_is_cuInit = False
|
16 | 17 |
|
17 | 18 | class DeviceProperties: |
@@ -938,37 +939,51 @@ class Device: |
938 | 939 | __slots__ = ("_id", "_mr", "_has_inited", "_properties") |
939 | 940 |
|
940 | 941 | def __new__(cls, device_id=None): |
| 942 | + global _is_cuInit |
| 943 | + if _is_cuInit is False: |
| 944 | + with _lock: |
| 945 | + handle_return(driver.cuInit(0)) |
| 946 | + _is_cuInit = True |
| 947 | + |
941 | 948 | # important: creating a Device instance does not initialize the GPU! |
942 | 949 | if device_id is None: |
943 | | - device_id = handle_return(runtime.cudaGetDevice()) |
| 950 | + err, dev = driver.cuCtxGetDevice() |
| 951 | + if err == 0: |
| 952 | + device_id = int(dev) |
| 953 | + else: |
| 954 | + ctx = handle_return(driver.cuCtxGetCurrent()) |
| 955 | + assert int(ctx) == 0 |
| 956 | + device_id = 0 # cudart behavior |
944 | 957 | assert isinstance(device_id, int), f"{device_id=}" |
945 | 958 | else: |
946 | | - total = handle_return(runtime.cudaGetDeviceCount()) |
| 959 | + total = handle_return(driver.cuDeviceGetCount()) |
947 | 960 | if not isinstance(device_id, int) or not (0 <= device_id < total): |
948 | 961 | raise ValueError(f"device_id must be within [0, {total}), got {device_id}") |
949 | 962 |
|
950 | 963 | # ensure Device is singleton |
951 | | - with _tls_lock: |
952 | | - if not hasattr(_tls, "devices"): |
953 | | - total = handle_return(runtime.cudaGetDeviceCount()) |
954 | | - _tls.devices = [] |
955 | | - for dev_id in range(total): |
956 | | - dev = super().__new__(cls) |
957 | | - dev._id = dev_id |
958 | | - # If the device is in TCC mode, or does not support memory pools for some other reason, |
959 | | - # use the SynchronousMemoryResource which does not use memory pools. |
960 | | - if ( |
961 | | - handle_return( |
962 | | - runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0) |
| 964 | + if not hasattr(_tls, "devices"): |
| 965 | + total = handle_return(driver.cuDeviceGetCount()) |
| 966 | + _tls.devices = [] |
| 967 | + for dev_id in range(total): |
| 968 | + dev = super().__new__(cls) |
| 969 | + |
| 970 | + dev._id = dev_id |
| 971 | + # If the device is in TCC mode, or does not support memory pools for some other reason, |
| 972 | + # use the SynchronousMemoryResource which does not use memory pools. |
| 973 | + if ( |
| 974 | + handle_return( |
| 975 | + driver.cuDeviceGetAttribute( |
| 976 | + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id |
963 | 977 | ) |
964 | | - ) == 1: |
965 | | - dev._mr = _DefaultAsyncMempool(dev_id) |
966 | | - else: |
967 | | - dev._mr = _SynchronousMemoryResource(dev_id) |
968 | | - |
969 | | - dev._has_inited = False |
970 | | - dev._properties = None |
971 | | - _tls.devices.append(dev) |
| 978 | + ) |
| 979 | + ) == 1: |
| 980 | + dev._mr = _DefaultAsyncMempool(dev_id) |
| 981 | + else: |
| 982 | + dev._mr = _SynchronousMemoryResource(dev_id) |
| 983 | + dev._has_inited = False |
| 984 | + dev._properties = None |
| 985 | + |
| 986 | + _tls.devices.append(dev) |
972 | 987 |
|
973 | 988 | return _tls.devices[device_id] |
974 | 989 |
|
|
0 commit comments