Skip to content

Commit 87405ad

Browse files
committed
avoid using cudart APIs in Device constructor
1 parent 2afcb20 commit 87405ad

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime
1212

1313
_tls = threading.local()
14-
_tls_lock = threading.Lock()
14+
_lock = threading.Lock()
15+
_is_cuInit = False
1516

1617

1718
class DeviceProperties:
@@ -938,37 +939,51 @@ class Device:
938939
__slots__ = ("_id", "_mr", "_has_inited", "_properties")
939940

940941
def __new__(cls, device_id=None):
942+
global _is_cuInit
943+
if _is_cuInit is False:
944+
with _lock:
945+
handle_return(driver.cuInit(0))
946+
_is_cuInit = True
947+
941948
# important: creating a Device instance does not initialize the GPU!
942949
if device_id is None:
943-
device_id = handle_return(runtime.cudaGetDevice())
950+
err, dev = driver.cuCtxGetDevice()
951+
if err == 0:
952+
device_id = int(dev)
953+
else:
954+
ctx = handle_return(driver.cuCtxGetCurrent())
955+
assert int(ctx) == 0
956+
device_id = 0 # cudart behavior: with no current context, default to device 0
944957
assert isinstance(device_id, int), f"{device_id=}"
945958
else:
946-
total = handle_return(runtime.cudaGetDeviceCount())
959+
total = handle_return(driver.cuDeviceGetCount())
947960
if not isinstance(device_id, int) or not (0 <= device_id < total):
948961
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")
949962

950963
# ensure Device is singleton
951-
with _tls_lock:
952-
if not hasattr(_tls, "devices"):
953-
total = handle_return(runtime.cudaGetDeviceCount())
954-
_tls.devices = []
955-
for dev_id in range(total):
956-
dev = super().__new__(cls)
957-
dev._id = dev_id
958-
# If the device is in TCC mode, or does not support memory pools for some other reason,
959-
# use the SynchronousMemoryResource which does not use memory pools.
960-
if (
961-
handle_return(
962-
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
964+
if not hasattr(_tls, "devices"):
965+
total = handle_return(driver.cuDeviceGetCount())
966+
_tls.devices = []
967+
for dev_id in range(total):
968+
dev = super().__new__(cls)
969+
970+
dev._id = dev_id
971+
# If the device is in TCC mode, or does not support memory pools for some other reason,
972+
# use the SynchronousMemoryResource which does not use memory pools.
973+
if (
974+
handle_return(
975+
driver.cuDeviceGetAttribute(
976+
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
963977
)
964-
) == 1:
965-
dev._mr = _DefaultAsyncMempool(dev_id)
966-
else:
967-
dev._mr = _SynchronousMemoryResource(dev_id)
968-
969-
dev._has_inited = False
970-
dev._properties = None
971-
_tls.devices.append(dev)
978+
)
979+
) == 1:
980+
dev._mr = _DefaultAsyncMempool(dev_id)
981+
else:
982+
dev._mr = _SynchronousMemoryResource(dev_id)
983+
dev._has_inited = False
984+
dev._properties = None
985+
986+
_tls.devices.append(dev)
972987

973988
return _tls.devices[device_id]
974989

cuda_core/tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def _device_unset_current():
4242
return
4343
handle_return(driver.cuCtxPopCurrent())
4444
if hasattr(_device._tls, "devices"):
45-
with _device._tls_lock:
46-
del _device._tls.devices
45+
del _device._tls.devices
4746

4847

4948
@pytest.fixture(scope="function")

0 commit comments

Comments
 (0)