Skip to content

Commit 5a1667b

Browse files
committed
Check binding version in addition to driver version
1 parent 2512b84 commit 5a1667b

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

cuda_core/cuda/core/experimental/_graph.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def _lazy_init():
2525
if _inited:
2626
return
2727

28-
global _py_major_ver, _driver_ver
28+
global _py_major_minor, _driver_ver
2929
# binding availability depends on cuda-python version
30-
_py_major_ver, _ = get_binding_version()
30+
_py_major_minor = get_binding_version()
3131
_driver_ver = handle_return(driver.cuDriverGetVersion())
3232
_inited = True
3333

@@ -276,7 +276,7 @@ def complete(self, options: Optional[CompleteOptions] = None) -> Graph:
276276
if not self._building_ended:
277277
raise RuntimeError("Graph has not finished building.")
278278

279-
if _driver_ver < 12000:
279+
if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
280280
flags = 0
281281
if options.auto_free_on_launch:
282282
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
@@ -468,6 +468,8 @@ def create_conditional_handle(self, default_value=None) -> int:
468468
"""
469469
if _driver_ver < 12030:
470470
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles")
471+
if _py_major_minor < (12, 3):
472+
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles")
471473
if default_value is not None:
472474
flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
473475
else:
@@ -539,6 +541,8 @@ def if_cond(self, handle: int) -> GraphBuilder:
539541
"""
540542
if _driver_ver < 12030:
541543
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if")
544+
if _py_major_minor < (12, 3):
545+
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if")
542546
node_params = driver.CUgraphNodeParams()
543547
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
544548
node_params.conditional.handle = handle
@@ -568,6 +572,8 @@ def if_else(self, handle: int) -> Tuple[GraphBuilder, GraphBuilder]:
568572
"""
569573
if _driver_ver < 12080:
570574
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else")
575+
if _py_major_minor < (12, 8):
576+
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else")
571577
node_params = driver.CUgraphNodeParams()
572578
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
573579
node_params.conditional.handle = handle
@@ -600,6 +606,8 @@ def switch(self, handle: int, count: int) -> Tuple[GraphBuilder, ...]:
600606
"""
601607
if _driver_ver < 12080:
602608
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch")
609+
if _py_major_minor < (12, 8):
610+
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch")
603611
node_params = driver.CUgraphNodeParams()
604612
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
605613
node_params.conditional.handle = handle
@@ -629,6 +637,8 @@ def while_loop(self, handle: int) -> GraphBuilder:
629637
"""
630638
if _driver_ver < 12030:
631639
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop")
640+
if _py_major_minor < (12, 3):
641+
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop")
632642
node_params = driver.CUgraphNodeParams()
633643
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
634644
node_params.conditional.handle = handle
@@ -741,10 +751,10 @@ def launch_graph(parent_graph: GraphBuilder, child_graph: GraphBuilder):
741751
742752
"""
743753

744-
if _driver_ver < 12000 or _py_major_ver < 12:
754+
if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
745755
raise NotImplementedError(
746756
f"Launching child graphs is not implemented for versions older than CUDA 12."
747-
f"Found driver version is {_driver_ver} and binding major version is {_py_major_ver}"
757+
f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
748758
)
749759

750760
if not child_graph._building_ended:

0 commit comments

Comments
 (0)