Skip to content

Commit 0a8a30d

Browse files
Improve program checks (#394)
* squash
* fix bug
* remove todo
* update the pr
* whitespace change

---------

Co-authored-by: Leo Fang <[email protected]>
1 parent 1e1148b commit 0a8a30d

File tree

4 files changed

+34
-15
lines changed

4 files changed

+34
-15
lines changed

cuda_core/cuda/core/experimental/_program.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,15 @@
55
import weakref
66
from dataclasses import dataclass
77
from typing import List, Optional, Tuple, Union
8+
from warnings import warn
89

910
from cuda.core.experimental._device import Device
1011
from cuda.core.experimental._linker import Linker, LinkerOptions
1112
from cuda.core.experimental._module import ObjectCode
1213
from cuda.core.experimental._utils import (
1314
_handle_boolean_option,
1415
check_or_create_options,
16+
driver,
1517
handle_return,
1618
is_nested_sequence,
1719
is_sequence,
@@ -378,6 +380,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
378380
raise TypeError("c++ Program expects code argument to be a string")
379381
# TODO: support pre-loaded headers & include names
380382
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
383+
381384
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
382385
self._backend = "nvrtc"
383386
self._linker = None
@@ -414,6 +417,11 @@ def close(self):
414417
self._linker.close()
415418
self._mnff.close()
416419

420+
def _can_load_generated_ptx(self):
    """Tell whether the NVRTC version is no newer than the CUDA driver version.

    The NVRTC ``(major, minor)`` pair is folded into a single integer,
    ``major * 1000 + minor * 10``, and compared against the value obtained
    from ``cuDriverGetVersion`` (presumably encoded the same way -- confirm
    against the driver API documentation).

    Returns:
        ``True`` when the folded NVRTC version is less than or equal to the
        driver version, ``False`` otherwise.
    """
    # Query the driver first, then NVRTC; handle_return unwraps each result.
    installed_driver = handle_return(driver.cuDriverGetVersion())
    major, minor = handle_return(nvrtc.nvrtcVersion())
    encoded_nvrtc = major * 1000 + minor * 10
    return encoded_nvrtc <= installed_driver
424+
417425
def compile(self, target_type, name_expressions=(), logs=None):
418426
"""Compile the program with a specific compilation type.
419427
@@ -440,6 +448,13 @@ def compile(self, target_type, name_expressions=(), logs=None):
440448
raise NotImplementedError
441449

442450
if self._backend == "nvrtc":
451+
if target_type == "ptx" and not self._can_load_generated_ptx():
452+
warn(
453+
"The CUDA driver version is older than the backend version. "
454+
"The generated ptx will not be loadable by the current driver.",
455+
stacklevel=1,
456+
category=RuntimeWarning,
457+
)
443458
if name_expressions:
444459
for n in name_expressions:
445460
handle_return(

cuda_core/tests/conftest.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,9 @@
1111
import sys
1212

1313
try:
14-
from cuda.bindings import driver, nvrtc
14+
from cuda.bindings import driver
1515
except ImportError:
1616
from cuda import cuda as driver
17-
from cuda import nvrtc
1817
import pytest
1918

2019
from cuda.core.experimental import Device, _device
@@ -66,9 +65,3 @@ def clean_up_cffi_files():
6665
os.remove(f)
6766
except FileNotFoundError:
6867
pass # noqa: SIM105
69-
70-
71-
def can_load_generated_ptx():
72-
_, driver_ver = driver.cuDriverGetVersion()
73-
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
74-
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver

cuda_core/tests/test_module.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,9 @@
77
# is strictly prohibited.
88

99

10+
import warnings
11+
1012
import pytest
11-
from conftest import can_load_generated_ptx
1213

1314
from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system
1415

@@ -40,10 +41,15 @@ def get_saxpy_kernel(init_cuda):
4041
return mod.get_kernel("saxpy<float>"), mod
4142

4243

43-
@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
4444
def test_get_kernel(init_cuda):
4545
kernel = """extern "C" __global__ void ABC() { }"""
46-
object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx")
46+
47+
with warnings.catch_warnings(record=True) as w:
48+
warnings.simplefilter("always")
49+
object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx")
50+
if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w):
51+
pytest.skip("PTX version too new for current driver")
52+
4753
assert object_code._handle is None
4854
kernel = object_code.get_kernel("ABC")
4955
assert object_code._handle is not None

cuda_core/tests/test_program.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,9 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88

9+
import warnings
10+
911
import pytest
10-
from conftest import can_load_generated_ptx
1112

1213
from cuda.core.experimental import _linker
1314
from cuda.core.experimental._module import Kernel, ObjectCode
@@ -100,13 +101,17 @@ def test_program_init_invalid_code_format():
100101
Program(code, "c++")
101102

102103

103-
# TODO: incorporate this check in Program
104104
# This is tested against the current device's arch
105-
@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
106105
def test_program_compile_valid_target_type(init_cuda):
107106
code = 'extern "C" __global__ void my_kernel() {}'
108107
program = Program(code, "c++")
109-
ptx_object_code = program.compile("ptx")
108+
109+
with warnings.catch_warnings(record=True) as w:
110+
warnings.simplefilter("always")
111+
ptx_object_code = program.compile("ptx")
112+
if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w):
113+
pytest.skip("PTX version too new for current driver")
114+
110115
program = Program(ptx_object_code._module.decode(), "ptx")
111116
cubin_object_code = program.compile("cubin")
112117
ptx_kernel = ptx_object_code.get_kernel("my_kernel")

0 commit comments

Comments
 (0)