@@ -25,9 +25,9 @@ def _lazy_init():
2525 if _inited :
2626 return
2727
28- global _py_major_ver , _driver_ver
28+ global _py_major_minor , _driver_ver
2929 # binding availability depends on cuda-python version
30- _py_major_ver , _ = get_binding_version ()
30+ _py_major_minor = get_binding_version ()
3131 _driver_ver = handle_return (driver .cuDriverGetVersion ())
3232 _inited = True
3333
@@ -276,7 +276,7 @@ def complete(self, options: Optional[CompleteOptions] = None) -> Graph:
276276 if not self ._building_ended :
277277 raise RuntimeError ("Graph has not finished building." )
278278
279- if _driver_ver < 12000 :
279+ if ( _driver_ver < 12000 ) or ( _py_major_minor < ( 12 , 0 )) :
280280 flags = 0
281281 if options .auto_free_on_launch :
282282 flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
@@ -468,6 +468,8 @@ def create_conditional_handle(self, default_value=None) -> int:
468468 """
469469 if _driver_ver < 12030 :
470470 raise RuntimeError (f"Driver version { _driver_ver } does not support conditional handles" )
471+ if _py_major_minor < (12 , 3 ):
472+ raise RuntimeError (f"Binding version { _py_major_minor } does not support conditional handles" )
471473 if default_value is not None :
472474 flags = driver .CU_GRAPH_COND_ASSIGN_DEFAULT
473475 else :
@@ -539,6 +541,8 @@ def if_cond(self, handle: int) -> GraphBuilder:
539541 """
540542 if _driver_ver < 12030 :
541543 raise RuntimeError (f"Driver version { _driver_ver } does not support conditional if" )
544+ if _py_major_minor < (12 , 3 ):
545+ raise RuntimeError (f"Binding version { _py_major_minor } does not support conditional if" )
542546 node_params = driver .CUgraphNodeParams ()
543547 node_params .type = driver .CUgraphNodeType .CU_GRAPH_NODE_TYPE_CONDITIONAL
544548 node_params .conditional .handle = handle
@@ -568,6 +572,8 @@ def if_else(self, handle: int) -> Tuple[GraphBuilder, GraphBuilder]:
568572 """
569573 if _driver_ver < 12080 :
570574 raise RuntimeError (f"Driver version { _driver_ver } does not support conditional if-else" )
575+ if _py_major_minor < (12 , 8 ):
576+ raise RuntimeError (f"Binding version { _py_major_minor } does not support conditional if-else" )
571577 node_params = driver .CUgraphNodeParams ()
572578 node_params .type = driver .CUgraphNodeType .CU_GRAPH_NODE_TYPE_CONDITIONAL
573579 node_params .conditional .handle = handle
@@ -600,6 +606,8 @@ def switch(self, handle: int, count: int) -> Tuple[GraphBuilder, ...]:
600606 """
601607 if _driver_ver < 12080 :
602608 raise RuntimeError (f"Driver version { _driver_ver } does not support conditional switch" )
609+ if _py_major_minor < (12 , 8 ):
610+ raise RuntimeError (f"Binding version { _py_major_minor } does not support conditional switch" )
603611 node_params = driver .CUgraphNodeParams ()
604612 node_params .type = driver .CUgraphNodeType .CU_GRAPH_NODE_TYPE_CONDITIONAL
605613 node_params .conditional .handle = handle
@@ -629,6 +637,8 @@ def while_loop(self, handle: int) -> GraphBuilder:
629637 """
630638 if _driver_ver < 12030 :
631639 raise RuntimeError (f"Driver version { _driver_ver } does not support conditional while loop" )
640+ if _py_major_minor < (12 , 3 ):
641+ raise RuntimeError (f"Binding version { _py_major_minor } does not support conditional while loop" )
632642 node_params = driver .CUgraphNodeParams ()
633643 node_params .type = driver .CUgraphNodeType .CU_GRAPH_NODE_TYPE_CONDITIONAL
634644 node_params .conditional .handle = handle
@@ -741,10 +751,10 @@ def launch_graph(parent_graph: GraphBuilder, child_graph: GraphBuilder):
741751
742752 """
743753
744- if _driver_ver < 12000 or _py_major_ver < 12 :
754+ if ( _driver_ver < 12000 ) or ( _py_major_minor < ( 12 , 0 )) :
745755 raise NotImplementedError (
746756 f"Launching child graphs is not implemented for versions older than CUDA 12."
747- f"Found driver version is { _driver_ver } and binding major version is { _py_major_ver } "
757+ f"Found driver version is { _driver_ver } and binding version is { _py_major_minor } "
748758 )
749759
750760 if not child_graph ._building_ended :
0 commit comments