diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 68571ebc..894e2165 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -116,8 +116,9 @@ def inner(*args, **kwargs): def get_device_from_ctx(ctx_handle) -> int: """Get device ID from the given ctx.""" - prev_ctx = Device().context.handle - if ctx_handle != prev_ctx: + from cuda.core.experimental._device import Device # avoid circular import + prev_ctx = Device().context._handle + if int(ctx_handle) != int(prev_ctx): switch_context = True else: switch_context = False diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index e0a98c18..6e5acd47 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -71,6 +71,16 @@ def test_stream_context(): context = stream.context assert context is not None +def test_stream_from_foreign_stream(): + device = Device() + other_stream = device.create_stream(options=StreamOptions()) + stream = device.create_stream(obj=other_stream) + assert other_stream.handle == stream.handle + device = stream.device + assert isinstance(device, Device) + context = stream.context + assert context is not None + def test_stream_from_handle(): stream = Stream.from_handle(0) assert isinstance(stream, Stream)