Skip to content

Commit 02c8e4e

Browse files
committed
inline @precondition to reduce overhead
1 parent 3495a1f commit 02c8e4e

File tree

2 files changed, with 58 additions and 35 deletions.

2 files changed

+58
-35
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 28 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
_check_driver_error,
1818
driver,
1919
handle_return,
20-
precondition,
2120
runtime,
2221
)
2322

@@ -1017,12 +1016,31 @@ def __new__(cls, device_id: Optional[int] = None):
10171016
except IndexError:
10181017
raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id}") from None
10191018

1020-
def _check_context_initialized(self, *args, **kwargs):
1019+
def _check_context_initialized(self):
10211020
if not self._has_inited:
10221021
raise CUDAError(
10231022
f"Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?"
10241023
)
10251024

1025+
def _get_current_context(self, check_consistency=False) -> driver.CUcontext:
1026+
err, ctx = driver.cuCtxGetCurrent()
1027+
1028+
# TODO: We want to just call this:
1029+
#_check_driver_error(err)
1030+
# but even the simplest success check causes 50-100 ns. Wait until we cythonize this file...
1031+
if ctx is None:
1032+
_check_driver_error(err)
1033+
1034+
if int(ctx) == 0:
1035+
raise CUDAError("No context is bound to the calling CPU thread.")
1036+
if check_consistency:
1037+
err, dev = driver.cuCtxGetDevice()
1038+
if err != _SUCCESS:
1039+
handle_return((err,))
1040+
if int(dev) != self._id:
1041+
raise CUDAError("Internal error (current device is not equal to Device.device_id)")
1042+
return ctx
1043+
10261044
@property
10271045
def device_id(self) -> int:
10281046
"""Return device ordinal."""
@@ -1083,7 +1101,6 @@ def compute_capability(self) -> ComputeCapability:
10831101
return cc
10841102

10851103
@property
1086-
@precondition(_check_context_initialized)
10871104
def context(self) -> Context:
10881105
"""Return the current :obj:`~_context.Context` associated with this device.
10891106
@@ -1092,9 +1109,8 @@ def context(self) -> Context:
10921109
Device must be initialized.
10931110
10941111
"""
1095-
ctx = handle_return(driver.cuCtxGetCurrent())
1096-
if int(ctx) == 0:
1097-
raise CUDAError("No context is bound to the calling CPU thread.")
1112+
self._check_context_initialized()
1113+
ctx = self._get_current_context(check_consistency=True)
10981114
return Context._from_ctx(ctx, self._id)
10991115

11001116
@property
@@ -1206,7 +1222,6 @@ def create_context(self, options: ContextOptions = None) -> Context:
12061222
"""
12071223
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189")
12081224

1209-
@precondition(_check_context_initialized)
12101225
def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions = None) -> Stream:
12111226
"""Create a Stream object.
12121227
@@ -1235,6 +1250,7 @@ def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions
12351250
Newly created stream object.
12361251
12371252
"""
1253+
self._check_context_initialized()
12381254
return Stream._init(obj=obj, options=options)
12391255

12401256
def create_event(self, options: Optional[EventOptions] = None) -> Event:
@@ -1255,12 +1271,10 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event:
12551271
Newly created event object.
12561272
12571273
"""
1258-
ctx = driver.cuCtxGetCurrent()[1]
1259-
if int(ctx) == 0:
1260-
raise CUDAError("No context is bound to the calling CPU thread.")
1274+
self._check_context_initialized()
1275+
ctx = self._get_current_context()
12611276
return Event._init(self._id, ctx, options)
12621277

1263-
@precondition(_check_context_initialized)
12641278
def allocate(self, size, stream: Optional[Stream] = None) -> Buffer:
12651279
"""Allocate device memory from a specified stream.
12661280
@@ -1287,11 +1301,11 @@ def allocate(self, size, stream: Optional[Stream] = None) -> Buffer:
12871301
Newly created buffer object.
12881302
12891303
"""
1304+
self._check_context_initialized()
12901305
if stream is None:
12911306
stream = default_stream()
12921307
return self._mr.allocate(size, stream)
12931308

1294-
@precondition(_check_context_initialized)
12951309
def sync(self):
12961310
"""Synchronize the device.
12971311
@@ -1300,9 +1314,9 @@ def sync(self):
13001314
Device must be initialized.
13011315
13021316
"""
1317+
self._check_context_initialized()
13031318
handle_return(runtime.cudaDeviceSynchronize())
13041319

1305-
@precondition(_check_context_initialized)
13061320
def create_graph_builder(self) -> GraphBuilder:
13071321
"""Create a new :obj:`~_graph.GraphBuilder` object.
13081322
@@ -1312,4 +1326,5 @@ def create_graph_builder(self) -> GraphBuilder:
13121326
Newly created graph builder object.
13131327
13141328
"""
1329+
self._check_context_initialized()
13151330
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)

cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx

Lines changed: 30 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -52,26 +52,29 @@ def _reduce_3_tuple(t: tuple):
5252
return t[0] * t[1] * t[2]
5353

5454

55-
cpdef inline void _check_driver_error(error) except*:
56-
if error == driver.CUresult.CUDA_SUCCESS:
57-
return
55+
cdef object _DRIVER_SUCCESS = driver.CUresult.CUDA_SUCCESS
56+
57+
58+
cpdef inline int _check_driver_error(error) except?-1:
59+
if error == _DRIVER_SUCCESS:
60+
return 0
5861
name_err, name = driver.cuGetErrorName(error)
59-
if name_err != driver.CUresult.CUDA_SUCCESS:
62+
if name_err != _DRIVER_SUCCESS:
6063
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
6164
name = name.decode()
6265
expl = DRIVER_CU_RESULT_EXPLANATIONS.get(int(error))
6366
if expl is not None:
6467
raise CUDAError(f"{name}: {expl}")
6568
desc_err, desc = driver.cuGetErrorString(error)
66-
if desc_err != driver.CUresult.CUDA_SUCCESS:
69+
if desc_err != _DRIVER_SUCCESS:
6770
raise CUDAError(f"{name}")
6871
desc = desc.decode()
6972
raise CUDAError(f"{name}: {desc}")
7073

7174

72-
cpdef inline void _check_runtime_error(error) except*:
75+
cpdef inline int _check_runtime_error(error) except?-1:
7376
if error == runtime.cudaError_t.cudaSuccess:
74-
return
77+
return 0
7578
name_err, name = runtime.cudaGetErrorName(error)
7679
if name_err != runtime.cudaError_t.cudaSuccess:
7780
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
@@ -86,30 +89,35 @@ cpdef inline void _check_runtime_error(error) except*:
8689
raise CUDAError(f"{name}: {desc}")
8790

8891

89-
cdef inline void _check_error(error, handle=None) except*:
92+
cpdef inline int _check_nvrtc_error(error, handle=None) except?-1:
93+
if error == nvrtc.nvrtcResult.NVRTC_SUCCESS:
94+
return 0
95+
err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}"
96+
if handle is not None:
97+
_, logsize = nvrtc.nvrtcGetProgramLogSize(handle)
98+
log = b" " * logsize
99+
_ = nvrtc.nvrtcGetProgramLog(handle, log)
100+
err += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}"
101+
raise NVRTCError(err)
102+
103+
104+
cdef inline int _check_error(error, handle=None) except?-1:
90105
if isinstance(error, driver.CUresult):
91-
_check_driver_error(error)
106+
return _check_driver_error(error)
92107
elif isinstance(error, runtime.cudaError_t):
93-
_check_runtime_error(error)
108+
return _check_runtime_error(error)
94109
elif isinstance(error, nvrtc.nvrtcResult):
95-
if error == nvrtc.nvrtcResult.NVRTC_SUCCESS:
96-
return
97-
err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}"
98-
if handle is not None:
99-
_, logsize = nvrtc.nvrtcGetProgramLogSize(handle)
100-
log = b" " * logsize
101-
_ = nvrtc.nvrtcGetProgramLog(handle, log)
102-
err += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}"
103-
raise NVRTCError(err)
110+
return _check_nvrtc_error(error, handle=handle)
104111
else:
105112
raise RuntimeError(f"Unknown error type: {error}")
106113

107114

108115
def handle_return(tuple result, handle=None):
109116
_check_error(result[0], handle=handle)
110-
if len(result) == 1:
117+
cdef int out_len = len(result)
118+
if out_len == 1:
111119
return
112-
elif len(result) == 2:
120+
elif out_len == 2:
113121
return result[1]
114122
else:
115123
return result[1:]
@@ -144,7 +152,7 @@ def _handle_boolean_option(option: bool) -> str:
144152
return "true" if bool(option) else "false"
145153

146154

147-
def precondition(checker: Callable[..., None], what: str = "") -> Callable:
155+
def precondition(checker: Callable[..., None], str what="") -> Callable:
148156
"""
149157
A decorator that adds checks to ensure any preconditions are met.
150158

0 commit comments

Comments
 (0)