1717 _check_driver_error ,
1818 driver ,
1919 handle_return ,
20- precondition ,
2120 runtime ,
2221)
2322
@@ -1017,12 +1016,31 @@ def __new__(cls, device_id: Optional[int] = None):
10171016 except IndexError :
10181017 raise ValueError (f"device_id must be within [0, { len (devices )} ), got { device_id } " ) from None
10191018
1020- def _check_context_initialized (self , * args , ** kwargs ):
1019+ def _check_context_initialized (self ):
10211020 if not self ._has_inited :
10221021 raise CUDAError (
10231022 f"Device { self ._id } is not yet initialized, perhaps you forgot to call .set_current() first?"
10241023 )
10251024
1025+ def _get_current_context (self , check_consistency = False ) -> driver .CUcontext :
1026+ err , ctx = driver .cuCtxGetCurrent ()
1027+
1028+ # TODO: We want to just call this:
1029+ #_check_driver_error(err)
1030+ # but even the simplest success check causes 50-100 ns. Wait until we cythonize this file...
1031+ if ctx is None :
1032+ _check_driver_error (err )
1033+
1034+ if int (ctx ) == 0 :
1035+ raise CUDAError ("No context is bound to the calling CPU thread." )
1036+ if check_consistency :
1037+ err , dev = driver .cuCtxGetDevice ()
1038+ if err != _SUCCESS :
1039+ handle_return ((err ,))
1040+ if int (dev ) != self ._id :
1041+ raise CUDAError ("Internal error (current device is not equal to Device.device_id)" )
1042+ return ctx
1043+
10261044 @property
10271045 def device_id (self ) -> int :
10281046 """Return device ordinal."""
@@ -1083,7 +1101,6 @@ def compute_capability(self) -> ComputeCapability:
10831101 return cc
10841102
10851103 @property
1086- @precondition (_check_context_initialized )
10871104 def context (self ) -> Context :
10881105 """Return the current :obj:`~_context.Context` associated with this device.
10891106
@@ -1092,9 +1109,8 @@ def context(self) -> Context:
10921109 Device must be initialized.
10931110
10941111 """
1095- ctx = handle_return (driver .cuCtxGetCurrent ())
1096- if int (ctx ) == 0 :
1097- raise CUDAError ("No context is bound to the calling CPU thread." )
1112+ self ._check_context_initialized ()
1113+ ctx = self ._get_current_context (check_consistency = True )
10981114 return Context ._from_ctx (ctx , self ._id )
10991115
11001116 @property
@@ -1206,7 +1222,6 @@ def create_context(self, options: ContextOptions = None) -> Context:
12061222 """
12071223 raise NotImplementedError ("WIP: https://github.com/NVIDIA/cuda-python/issues/189" )
12081224
1209- @precondition (_check_context_initialized )
12101225 def create_stream (self , obj : Optional [IsStreamT ] = None , options : StreamOptions = None ) -> Stream :
12111226 """Create a Stream object.
12121227
@@ -1235,6 +1250,7 @@ def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions
12351250 Newly created stream object.
12361251
12371252 """
1253+ self ._check_context_initialized ()
12381254 return Stream ._init (obj = obj , options = options )
12391255
12401256 def create_event (self , options : Optional [EventOptions ] = None ) -> Event :
@@ -1255,12 +1271,10 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event:
12551271 Newly created event object.
12561272
12571273 """
1258- ctx = driver .cuCtxGetCurrent ()[1 ]
1259- if int (ctx ) == 0 :
1260- raise CUDAError ("No context is bound to the calling CPU thread." )
1274+ self ._check_context_initialized ()
1275+ ctx = self ._get_current_context ()
12611276 return Event ._init (self ._id , ctx , options )
12621277
1263- @precondition (_check_context_initialized )
12641278 def allocate (self , size , stream : Optional [Stream ] = None ) -> Buffer :
12651279 """Allocate device memory from a specified stream.
12661280
@@ -1287,11 +1301,11 @@ def allocate(self, size, stream: Optional[Stream] = None) -> Buffer:
12871301 Newly created buffer object.
12881302
12891303 """
1304+ self ._check_context_initialized ()
12901305 if stream is None :
12911306 stream = default_stream ()
12921307 return self ._mr .allocate (size , stream )
12931308
1294- @precondition (_check_context_initialized )
12951309 def sync (self ):
12961310 """Synchronize the device.
12971311
@@ -1300,9 +1314,9 @@ def sync(self):
13001314 Device must be initialized.
13011315
13021316 """
1317+ self ._check_context_initialized ()
13031318 handle_return (runtime .cudaDeviceSynchronize ())
13041319
1305- @precondition (_check_context_initialized )
13061320 def create_graph_builder (self ) -> GraphBuilder :
13071321 """Create a new :obj:`~_graph.GraphBuilder` object.
13081322
@@ -1312,4 +1326,5 @@ def create_graph_builder(self) -> GraphBuilder:
13121326 Newly created graph builder object.
13131327
13141328 """
1329+ self ._check_context_initialized ()
13151330 return GraphBuilder ._init (stream = self .create_stream (), is_stream_owner = True )
0 commit comments