From 90fa8f666ce81721e6945a32a44237be88aecce1 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 30 Oct 2024 15:11:23 -0700 Subject: [PATCH 1/7] remove the method and its references --- cuda_core/cuda/core/experimental/_stream.py | 9 --------- cuda_core/cuda/core/experimental/_utils.py | 17 ----------------- 2 files changed, 26 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_stream.py b/cuda_core/cuda/core/experimental/_stream.py index 95f8ec50..0e8f8141 100644 --- a/cuda_core/cuda/core/experimental/_stream.py +++ b/cuda_core/cuda/core/experimental/_stream.py @@ -14,7 +14,6 @@ from cuda.core.experimental._context import Context from cuda.core.experimental._event import Event, EventOptions from cuda.core.experimental._utils import check_or_create_options -from cuda.core.experimental._utils import get_device_from_ctx from cuda.core.experimental._utils import handle_return @@ -182,12 +181,6 @@ def device(self) -> Device: # Stream.context, in cases where a different CUDA context is set # current after a stream was created. from cuda.core.experimental._device import Device # avoid circular import - if self._device_id is None: - # Get the stream context first - if self._ctx_handle is None: - self._ctx_handle = handle_return( - cuda.cuStreamGetCtx(self._handle)) - self._device_id = get_device_from_ctx(self._ctx_handle) return Device(self._device_id) @property @@ -197,8 +190,6 @@ def context(self) -> Context: if self._ctx_handle is None: self._ctx_handle = handle_return( cuda.cuStreamGetCtx(self._handle)) - if self._device_id is None: - self._device_id = get_device_from_ctx(self._ctx_handle) return Context._from_ctx(self._ctx_handle, self._device_id) @staticmethod diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 68571ebc..f19e637f 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -112,20 +112,3 @@ def inner(*args, **kwargs): return inner return outer - - -def get_device_from_ctx(ctx_handle) -> int: - """Get device ID from the given ctx.""" - prev_ctx = Device().context.handle - if ctx_handle != prev_ctx: - switch_context = True - else: - switch_context = False - if switch_context: - assert prev_ctx == handle_return(cuda.cuCtxPopCurrent()) - handle_return(cuda.cuCtxPushCurrent(ctx_handle)) - device_id = int(handle_return(cuda.cuCtxGetDevice())) - if switch_context: - assert ctx_handle == handle_return(cuda.cuCtxPopCurrent()) - handle_return(cuda.cuCtxPushCurrent(prev_ctx)) - return device_id From 0579f3a628270de7eceed213eb1a1aebb083b376 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Thu, 31 Oct 2024 10:34:33 -0700 Subject: [PATCH 2/7] add failing test --- cuda_core/tests/test_stream.py | 42 +++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index e0a98c18..780609a8 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -6,28 +6,40 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. -from cuda.core.experimental._stream import Stream, StreamOptions, LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream -from cuda.core.experimental._event import Event, EventOptions -from cuda.core.experimental._device import Device import pytest +from cuda.core.experimental._device import Device +from cuda.core.experimental._event import Event +from cuda.core.experimental._stream import ( + LEGACY_DEFAULT_STREAM, + PER_THREAD_DEFAULT_STREAM, + Stream, + StreamOptions, + default_stream, +) + + def test_stream_init(): with pytest.raises(NotImplementedError): Stream() + def test_stream_init_with_options(): stream = Stream._init(options=StreamOptions(nonblocking=True, priority=0)) assert stream.is_nonblocking is True assert stream.priority == 0 + def test_stream_handle(): stream = Stream._init(options=StreamOptions()) assert isinstance(stream.handle, int) + def test_stream_is_nonblocking(): stream = Stream._init(options=StreamOptions(nonblocking=True)) assert stream.is_nonblocking is True + def test_stream_priority(): stream = Stream._init(options=StreamOptions(priority=0)) assert stream.priority == 0 @@ -36,51 +48,75 @@ def test_stream_priority(): with pytest.raises(ValueError): stream = Stream._init(options=StreamOptions(priority=1)) + def test_stream_sync(): stream = Stream._init(options=StreamOptions()) stream.sync() # Should not raise any exceptions + def test_stream_record(): stream = Stream._init(options=StreamOptions()) event = stream.record() assert isinstance(event, Event) + def test_stream_record_invalid_event(): stream = Stream._init(options=StreamOptions()) with pytest.raises(TypeError): stream.record(event="invalid_event") + def test_stream_wait_event(): stream = Stream._init(options=StreamOptions()) event = Event._init() stream.record(event) stream.wait(event) # Should not raise any exceptions + def test_stream_wait_invalid_event(): stream = Stream._init(options=StreamOptions()) with pytest.raises(ValueError): stream.wait(event_or_stream="invalid_event") + def test_stream_device(): stream = Stream._init(options=StreamOptions()) device = stream.device assert isinstance(device, Device) + def test_stream_context(): stream = Stream._init(options=StreamOptions()) context = stream.context assert context is not None + def test_stream_from_handle(): stream = Stream.from_handle(0) assert isinstance(stream, Stream) +def test_stream_device_with_foreign_stream(): + device = Device() + other_stream = Stream._init(options=StreamOptions()) + stream = device.create_stream(obj=other_stream) + device = stream.device + assert isinstance(device, Device) + +def test_stream_context_with_foreign_stream(): + device = Device() + other_stream = Stream._init(options=StreamOptions()) + stream = device.create_stream(obj=other_stream) + context = stream.context + assert context is not None + def test_legacy_default_stream(): assert isinstance(LEGACY_DEFAULT_STREAM, Stream) + def test_per_thread_default_stream(): assert isinstance(PER_THREAD_DEFAULT_STREAM, Stream) + def test_default_stream(): stream = default_stream() assert isinstance(stream, Stream) From ee55aa513d1cece68bbb1e2374e7c8d9bcf6b0bd Mon Sep 17 00:00:00 2001 From: ksimpson Date: Thu, 31 Oct 2024 10:49:20 -0700 Subject: [PATCH 3/7] fix circular import issue --- cuda_core/cuda/core/experimental/_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index f19e637f..4d8be372 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -112,3 +112,21 @@ def inner(*args, **kwargs): return inner return outer + + +def get_device_from_ctx(ctx_handle) -> int: + """Get device ID from the given ctx.""" + from cuda.core.experimental._device import Device # avoid circular import + prev_ctx = Device().context.handle + if ctx_handle != prev_ctx: + switch_context = True + else: + switch_context = False + if switch_context: + assert prev_ctx == handle_return(cuda.cuCtxPopCurrent()) + handle_return(cuda.cuCtxPushCurrent(ctx_handle)) + device_id = int(handle_return(cuda.cuCtxGetDevice())) + if switch_context: + assert ctx_handle == handle_return(cuda.cuCtxPopCurrent()) + handle_return(cuda.cuCtxPushCurrent(prev_ctx)) + return device_id From 72deeb6f13a0e4e91137ecfb59b2fbb432cf9e90 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Thu, 31 Oct 2024 11:16:00 -0700 Subject: [PATCH 4/7] fix a couple little things in _utils method --- cuda_core/cuda/core/experimental/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 4d8be372..894e2165 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -117,8 +117,8 @@ def inner(*args, **kwargs): def get_device_from_ctx(ctx_handle) -> int: """Get device ID from the given ctx.""" from cuda.core.experimental._device import Device # avoid circular import - prev_ctx = Device().context.handle - if ctx_handle != prev_ctx: + prev_ctx = Device().context._handle + if int(ctx_handle) != int(prev_ctx): switch_context = True else: switch_context = False From 874261d7248cbab555ede4049a088b78d0eac88a Mon Sep 17 00:00:00 2001 From: ksimpson Date: Thu, 31 Oct 2024 11:16:51 -0700 Subject: [PATCH 5/7] revert stream changes --- cuda_core/cuda/core/experimental/_stream.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cuda_core/cuda/core/experimental/_stream.py b/cuda_core/cuda/core/experimental/_stream.py index 0e8f8141..95f8ec50 100644 --- a/cuda_core/cuda/core/experimental/_stream.py +++ b/cuda_core/cuda/core/experimental/_stream.py @@ -14,6 +14,7 @@ from cuda.core.experimental._context import Context from cuda.core.experimental._event import Event, EventOptions from cuda.core.experimental._utils import check_or_create_options +from cuda.core.experimental._utils import get_device_from_ctx from cuda.core.experimental._utils import handle_return @@ -181,6 +182,12 @@ def device(self) -> Device: # Stream.context, in cases where a different CUDA context is set # current after a stream was created. from cuda.core.experimental._device import Device # avoid circular import + if self._device_id is None: + # Get the stream context first + if self._ctx_handle is None: + self._ctx_handle = handle_return( + cuda.cuStreamGetCtx(self._handle)) + self._device_id = get_device_from_ctx(self._ctx_handle) return Device(self._device_id) @property @@ -190,6 +197,8 @@ def context(self) -> Context: if self._ctx_handle is None: self._ctx_handle = handle_return( cuda.cuStreamGetCtx(self._handle)) + if self._device_id is None: + self._device_id = get_device_from_ctx(self._ctx_handle) return Context._from_ctx(self._ctx_handle, self._device_id) @staticmethod From 0c4d86f4de1bfed48d65fb11420eb49ae0c36642 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Thu, 31 Oct 2024 11:21:16 -0700 Subject: [PATCH 6/7] revert whitespaces changes.. leave them for the ruff lint review --- cuda_core/tests/test_stream.py | 36 +++++++--------------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 780609a8..2ea99816 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -6,40 +6,28 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. -import pytest - +from cuda.core.experimental._stream import Stream, StreamOptions, LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream +from cuda.core.experimental._event import Event, EventOptions from cuda.core.experimental._device import Device -from cuda.core.experimental._event import Event -from cuda.core.experimental._stream import ( - LEGACY_DEFAULT_STREAM, - PER_THREAD_DEFAULT_STREAM, - Stream, - StreamOptions, - default_stream, -) - +import pytest def test_stream_init(): with pytest.raises(NotImplementedError): Stream() - def test_stream_init_with_options(): stream = Stream._init(options=StreamOptions(nonblocking=True, priority=0)) assert stream.is_nonblocking is True assert stream.priority == 0 - def test_stream_handle(): stream = Stream._init(options=StreamOptions()) assert isinstance(stream.handle, int) - def test_stream_is_nonblocking(): stream = Stream._init(options=StreamOptions(nonblocking=True)) assert stream.is_nonblocking is True - def test_stream_priority(): stream = Stream._init(options=StreamOptions(priority=0)) assert stream.priority == 0 @@ -48,53 +36,41 @@ def test_stream_priority(): with pytest.raises(ValueError): stream = Stream._init(options=StreamOptions(priority=1)) - def test_stream_sync(): stream = Stream._init(options=StreamOptions()) stream.sync() # Should not raise any exceptions - def test_stream_record(): stream = Stream._init(options=StreamOptions()) event = stream.record() assert isinstance(event, Event) - def test_stream_record_invalid_event(): stream = Stream._init(options=StreamOptions()) with pytest.raises(TypeError): stream.record(event="invalid_event") - def test_stream_wait_event(): stream = Stream._init(options=StreamOptions()) event = Event._init() stream.record(event) stream.wait(event) # Should not raise any exceptions - def test_stream_wait_invalid_event(): stream = Stream._init(options=StreamOptions()) with pytest.raises(ValueError): stream.wait(event_or_stream="invalid_event") - def test_stream_device(): stream = Stream._init(options=StreamOptions()) device = stream.device assert isinstance(device, Device) - def test_stream_context(): stream = Stream._init(options=StreamOptions()) context = stream.context assert context is not None - -def test_stream_from_handle(): - stream = Stream.from_handle(0) - assert isinstance(stream, Stream) - def test_stream_device_with_foreign_stream(): device = Device() other_stream = Stream._init(options=StreamOptions()) @@ -108,15 +84,17 @@ def test_stream_context_with_foreign_stream(): stream = device.create_stream(obj=other_stream) context = stream.context assert context is not None + +def test_stream_from_handle(): + stream = Stream.from_handle(0) + assert isinstance(stream, Stream) def test_legacy_default_stream(): assert isinstance(LEGACY_DEFAULT_STREAM, Stream) - def test_per_thread_default_stream(): assert isinstance(PER_THREAD_DEFAULT_STREAM, Stream) - def test_default_stream(): stream = default_stream() assert isinstance(stream, Stream) From 925a42de2d1ea5bf78ce3a6c72cf4be4b3fc7a75 Mon Sep 17 00:00:00 2001 From: ksimpson-work Date: Mon, 4 Nov 2024 09:23:54 -0800 Subject: [PATCH 7/7] Update cuda_core/tests/test_stream.py Co-authored-by: Leo Fang --- cuda_core/tests/test_stream.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 2ea99816..6e5acd47 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -71,17 +71,13 @@ def test_stream_context(): context = stream.context assert context is not None -def test_stream_device_with_foreign_stream(): +def test_stream_from_foreign_stream(): device = Device() - other_stream = Stream._init(options=StreamOptions()) + other_stream = device.create_stream(options=StreamOptions()) stream = device.create_stream(obj=other_stream) + assert other_stream.handle == stream.handle device = stream.device assert isinstance(device, Device) - -def test_stream_context_with_foreign_stream(): - device = Device() - other_stream = Stream._init(options=StreamOptions()) - stream = device.create_stream(obj=other_stream) context = stream.context assert context is not None