Skip to content

Commit fdbad93

Browse files
Merge branch 'main' into feature/occupancy
2 parents 5968ff0 + ed884c9 commit fdbad93

File tree

5 files changed

+43
-4
lines changed

5 files changed

+43
-4
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event:
12281228
Newly created event object.
12291229
12301230
"""
1231-
return Event._init(options)
1231+
return Event._init(self._id, self.context._handle, options)
12321232

12331233
@precondition(_check_context_initialized)
12341234
def allocate(self, size, stream=None) -> Buffer:

cuda_core/cuda/core/experimental/_event.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dataclasses import dataclass
99
from typing import TYPE_CHECKING, Optional
1010

11+
from cuda.core.experimental._context import Context
1112
from cuda.core.experimental._utils.cuda_utils import (
1213
CUDAError,
1314
check_or_create_options,
@@ -20,6 +21,7 @@
2021

2122
if TYPE_CHECKING:
2223
import cuda.bindings
24+
from cuda.core.experimental._device import Device
2325

2426

2527
@dataclass
@@ -91,10 +93,10 @@ def close(self):
9193
def __new__(self, *args, **kwargs):
9294
raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).")
9395

94-
__slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited")
96+
__slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited", "_device_id", "_ctx_handle")
9597

9698
@classmethod
97-
def _init(cls, options: Optional[EventOptions] = None):
99+
def _init(cls, device_id: int, ctx_handle: Context, options: Optional[EventOptions] = None):
98100
self = super().__new__(cls)
99101
self._mnff = Event._MembersNeededForFinalize(self, None)
100102

@@ -111,6 +113,8 @@ def _init(cls, options: Optional[EventOptions] = None):
111113
if options.support_ipc:
112114
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
113115
self._mnff.handle = handle_return(driver.cuEventCreate(flags))
116+
self._device_id = device_id
117+
self._ctx_handle = ctx_handle
114118
return self
115119

116120
def close(self):
@@ -198,3 +202,24 @@ def handle(self) -> cuda.bindings.driver.CUevent:
198202
handle, call ``int(Event.handle)``.
199203
"""
200204
return self._mnff.handle
205+
206+
@property
207+
def device(self) -> Device:
208+
"""Return the :obj:`~_device.Device` singleton associated with this event.
209+
210+
Note
211+
----
212+
The current context on the device may differ from this
213+
event's context. This case occurs when a different CUDA
214+
context is set current after a event is created.
215+
216+
"""
217+
218+
from cuda.core.experimental._device import Device # avoid circular import
219+
220+
return Device(self._device_id)
221+
222+
@property
223+
def context(self) -> Context:
224+
"""Return the :obj:`~_context.Context` associated with this event."""
225+
return Context._from_ctx(self._ctx_handle, self._device_id)

cuda_core/cuda/core/experimental/_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def record(self, event: Event = None, options: EventOptions = None) -> Event:
244244
# on the stream. Event flags such as disabling timing, nonblocking,
245245
# and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions.
246246
if event is None:
247-
event = Event._init(options)
247+
event = Event._init(self._device_id, self._ctx_handle, options)
248248
assert_type(event, Event)
249249
handle_return(driver.cuEventRecord(event.handle, self._mnff.handle))
250250
return event

cuda_core/docs/source/release/0.3.0-notes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,5 @@ New examples
2828

2929
Fixes and enhancements
3030
----------------------
31+
32+
- An :class:`Event` can now be used to look up its corresponding device and context using the ``.device`` and ``.context`` attributes respectively.

cuda_core/tests/test_event.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,15 @@ def test_error_timing_incomplete():
169169
arr[0] = 1
170170
event3.sync()
171171
event3 - event1 # this should work
172+
173+
174+
def test_event_device(init_cuda):
175+
device = Device()
176+
event = device.create_event(options=EventOptions())
177+
assert event.device is device
178+
179+
180+
def test_event_context(init_cuda):
181+
event = Device().create_event(options=EventOptions())
182+
context = event.context
183+
assert context is not None

0 commit comments

Comments
 (0)