Commit 95777c4

avoid silly, redundant lock
1 parent 2afcb20 commit 95777c4

File tree

2 files changed: +29 -24 lines

cuda_core/cuda/core/experimental/_device.py

Lines changed: 28 additions & 22 deletions

@@ -11,7 +11,8 @@
 from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime

 _tls = threading.local()
-_tls_lock = threading.Lock()
+_lock = threading.Lock()
+_is_cuInit = False


 class DeviceProperties:

@@ -938,6 +939,12 @@ class Device:
     __slots__ = ("_id", "_mr", "_has_inited", "_properties")

     def __new__(cls, device_id=None):
+        global _is_cuInit
+        if _is_cuInit is False:
+            with _lock:
+                handle_return(driver.cuInit(0))
+                _is_cuInit = True
+
         # important: creating a Device instance does not initialize the GPU!
         if device_id is None:
             device_id = handle_return(runtime.cudaGetDevice())

@@ -948,27 +955,26 @@ def __new__(cls, device_id=None):
                 raise ValueError(f"device_id must be within [0, {total}), got {device_id}")

         # ensure Device is singleton
-        with _tls_lock:
-            if not hasattr(_tls, "devices"):
-                total = handle_return(runtime.cudaGetDeviceCount())
-                _tls.devices = []
-                for dev_id in range(total):
-                    dev = super().__new__(cls)
-                    dev._id = dev_id
-                    # If the device is in TCC mode, or does not support memory pools for some other reason,
-                    # use the SynchronousMemoryResource which does not use memory pools.
-                    if (
-                        handle_return(
-                            runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
-                        )
-                    ) == 1:
-                        dev._mr = _DefaultAsyncMempool(dev_id)
-                    else:
-                        dev._mr = _SynchronousMemoryResource(dev_id)
-
-                    dev._has_inited = False
-                    dev._properties = None
-                    _tls.devices.append(dev)
+        if not hasattr(_tls, "devices"):
+            total = handle_return(runtime.cudaGetDeviceCount())
+            _tls.devices = []
+            for dev_id in range(total):
+                dev = super().__new__(cls)
+                dev._id = dev_id
+                # If the device is in TCC mode, or does not support memory pools for some other reason,
+                # use the SynchronousMemoryResource which does not use memory pools.
+                if (
+                    handle_return(
+                        runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
+                    )
+                ) == 1:
+                    dev._mr = _DefaultAsyncMempool(dev_id)
+                else:
+                    dev._mr = _SynchronousMemoryResource(dev_id)
+
+                dev._has_inited = False
+                dev._properties = None
+                _tls.devices.append(dev)

         return _tls.devices[device_id]
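For context, the new module-level `_lock` plus `_is_cuInit` flag is a standard one-time-initialization guard: the flag is read without the lock on the fast path, and the lock only serializes the very first call. A minimal sketch of that pattern, independent of this commit (the `_expensive_init` name is hypothetical; the re-check inside the lock is a common refinement when the guarded call is not idempotent, whereas `cuInit` tolerates repeated calls):

    import threading

    _lock = threading.Lock()
    _initialized = False

    def _expensive_init():
        # Hypothetical stand-in for a one-shot call such as driver.cuInit(0).
        pass

    def ensure_initialized():
        global _initialized
        if not _initialized:          # fast path: no lock once the flag is set
            with _lock:
                if not _initialized:  # re-check under the lock so only one thread runs the init
                    _expensive_init()
                    _initialized = True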

cuda_core/tests/conftest.py

Lines changed: 1 addition & 2 deletions

@@ -42,8 +42,7 @@ def _device_unset_current():
         return
     handle_return(driver.cuCtxPopCurrent())
     if hasattr(_device._tls, "devices"):
-        with _device._tls_lock:
-            del _device._tls.devices
+        del _device._tls.devices


 @pytest.fixture(scope="function")
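The dropped `_tls_lock` guarded state stored in `threading.local()`, which is already private to each thread: every thread lazily builds its own `devices` list, so no two threads can race on it. A small, self-contained sketch of that property (the names here are illustrative, not from the library):

    import threading

    _tls = threading.local()

    def get_cache():
        # Attributes on a threading.local() instance are per-thread, so this
        # lazy initialization needs no lock: each thread creates its own list.
        if not hasattr(_tls, "cache"):
            _tls.cache = []
        return _tls.cache

    def worker(n):
        get_cache().append(n)
        # Appends done by other threads are never visible here.
        assert get_cache() == [n]

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()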

0 commit comments
