Skip to content

Commit 6d037dc

Browse files
committed
Check NVRTC error correctly
1 parent 4f54bb5 commit 6d037dc

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cuda_core/tests/test_graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import pytest
77

88
try:
9-
from cuda.bindings import driver
9+
from cuda.bindings import nvrtc
1010
except ImportError:
11-
from cuda import cuda as driver
11+
from cuda import nvrtc
1212
from cuda.core.experimental import (
1313
CompleteOptions,
1414
DebugPrintOptions,
@@ -55,9 +55,9 @@ def _common_kernels_conditional():
5555
try:
5656
mod = prog.compile("cubin", name_expressions=("empty_kernel", "add_one", "set_handle", "loop_kernel"))
5757
except NVRTCError as e:
58-
with pytest.raises(RuntimeError, match='error: identifier "cudaGraphConditionalHandle" is undefined'):
58+
with pytest.raises(NVRTCError, match='error: identifier "cudaGraphConditionalHandle" is undefined'):
5959
raise e
60-
nvrtcVersion = handle_return(driver.nvrtcVersion())
60+
nvrtcVersion = handle_return(nvrtc.nvrtcVersion())
6161
pytest.skip(f"NVRTC version {nvrtcVersion} does not support conditionals")
6262
return mod
6363

0 commit comments

Comments
 (0)