Skip to content

Commit a0ab9d8

Browse files
committed
add pytest support for tutorial gemm
1 parent 5016493 commit a0ab9d8

File tree

6 files changed

+273
-8
lines changed

6 files changed

+273
-8
lines changed

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,15 @@ def kernel(
141141
# (bM, bN)
142142
gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
143143
thr_mma = tiled_mma.get_slice(0)
144-
# (MMA, MMA_M, MMA_K)
144+
# (MMA, MMA_M, MMA_K, RestK)
145145
tCgA = thr_mma.partition_A(gA)
146-
# (MMA, MMA_N, MMA_K)
146+
# (MMA, MMA_N, MMA_K, RestK)
147147
tCgB = thr_mma.partition_B(gB)
148148
# (MMA, MMA_M, MMA_N)
149149
tCgC = thr_mma.partition_C(gC)
150-
# (MMA, MMA_M, MMA_K)
150+
# (MMA, MMA_M, MMA_K, STAGE)
151151
tCrA = tiled_mma.make_fragment_A(sA)
152-
# (MMA, MMA_N, MMA_K)
152+
# (MMA, MMA_N, MMA_K, STAGE)
153153
tCrB = tiled_mma.make_fragment_B(sB)
154154
# (MMA, MMA_M, MMA_N)
155155
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,15 @@ def kernel(
174174
# (bM, bN)
175175
gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
176176
thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0])
177-
# (MMA, MMA_M, MMA_K)
177+
# (MMA, MMA_M, MMA_K, RestK)
178178
tCgA = thr_mma.partition_A(gA)
179-
# (MMA, MMA_N, MMA_K)
179+
# (MMA, MMA_N, MMA_K, RestK)
180180
tCgB = thr_mma.partition_B(gB)
181181
# (MMA, MMA_M, MMA_N)
182182
tCgC = thr_mma.partition_C(gC)
183-
# (MMA, MMA_M, MMA_K)
183+
# (MMA, MMA_M, MMA_K, STAGE)
184184
tCrA = tiled_mma.make_fragment_A(sA)
185-
# (MMA, MMA_N, MMA_K)
185+
# (MMA, MMA_N, MMA_K, STAGE)
186186
tCrB = tiled_mma.make_fragment_B(sB)
187187
# (MMA, MMA_M, MMA_N)
188188
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])

test/examples/CuTeDSL/conftest.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
7+
# 1. Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
10+
# 2. Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
14+
# 3. Neither the name of the copyright holder nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import pytest
30+
import logging
31+
import sys
32+
from pathlib import Path
33+
34+
import torch
35+
import numpy as np
36+
37+
38+
# Resolve the repository root (four directory levels above this file) and make
# the CuTeDSL examples and the shared test utilities importable.
project_root = Path(__file__).resolve().parent.parent.parent.parent
example_path = project_root.joinpath("examples", "python", "CuTeDSL")
utils_path = project_root.joinpath("test", "utils")
sys.path.extend(str(extra_dir) for extra_dir in (example_path, utils_path))
43+
44+
# The helper class to prevent modification of sys.path from test files.
# Only allow modification of sys.path from pytest monkeypatch API calls.
class ImmutableSysPath(list):
    """A ``list`` that refuses in-place mutation from project module code.

    Every method named in ``mutating_methods`` is replaced by a guard that
    inspects the immediate caller's frame. When that frame's locals contain
    ``__file__`` (as they do for code executing at the top level of a module)
    and that path starts with ``project_root``, the guard raises
    ``RuntimeError``. Any other caller is forwarded to the plain ``list``
    implementation, so calls made from inside functions -- such as pytest's
    ``monkeypatch.syspath_prepend`` -- keep working.
    """

    mutating_methods = {
        "append",
        "extend",
        "insert",
        "remove",
        "pop",
        "clear",
        "reverse",
        "sort",
        "__setitem__",
        "__delitem__",
        # Augmented assignment (``sys.path += [...]`` / ``sys.path *= n``)
        # mutates in place too and must not slip past the guard.
        "__iadd__",
        "__imul__",
    }

    for mtd in mutating_methods:

        # ``mtd=mtd`` binds the current name at definition time; without it
        # every wrapper would see the final value of the loop variable.
        def mutating_method(self, *args, mtd=mtd, **kwargs):
            caller = sys._getframe().f_back
            caller_file = None
            if caller is not None and "__file__" in caller.f_locals:
                caller_file = caller.f_locals["__file__"]
            # ``__file__`` may be absent or None; only a real path located
            # under the project root is rejected.
            if isinstance(caller_file, str) and caller_file.startswith(
                str(project_root)
            ):
                err_msg = (
                    "Modification of sys.path is forbidden in test file! "
                    "Please use pytest monkeypatch.syspath_prepend(...) instead."
                )
                raise RuntimeError(err_msg)
            return getattr(super(), mtd)(*args, **kwargs)

        # Writing into the class namespace installs the wrapper under the
        # name of the list method it guards.
        locals()[mtd] = mutating_method

    def __init__(self, initial=None):
        """Build the list from ``initial`` (any iterable), or empty if None."""
        if initial is None:
            initial = []
        super().__init__(initial)
83+
84+
85+
# Swap in the guarded list, seeded with a copy of the current entries, so that
# later in-place sys.path mutations go through ImmutableSysPath's checks.
sys.path = ImmutableSysPath(list(sys.path))
86+
87+
# Plugin modules pytest should load for this test tree.
# NOTE(review): "test_sharding" is presumably importable via utils_path, which
# is appended to sys.path earlier in this file -- confirm.
pytest_plugins = ["test_sharding"]
88+
89+
def pytest_addoption(parser):
    """Register the ``--sample-interval`` command-line option.

    The option stores an integer and defaults to 4.
    """
    option_spec = {
        "action": "store",
        "type": int,
        "default": 4,
        "help": "If value x is provided, then 1 / x of random picked tests will be run",
    }
    parser.addoption("--sample-interval", **option_spec)
97+
98+
99+
@pytest.fixture
def sample_interval(request):
    """Fixture exposing the value given for the ``--sample-interval`` option."""
    option_value = request.config.getoption("--sample-interval")
    return option_value
102+
103+
104+
# Removes all StreamHandlers from loggers at the end of test session.
# This prevents errors when atexit-registered functions try to use loggers
# whose handlers have already been closed during pytest teardown.
@pytest.fixture(scope="session", autouse=True)
def cleanup_logging_handlers():
    """Session-scoped autouse fixture that detaches stream handlers on teardown.

    Nothing happens before the session. Once it ends (the ``finally`` runs
    even if an exception is thrown into the generator), every
    ``logging.StreamHandler`` instance -- subclasses included -- is removed
    from the root logger and from each entry of the logging manager's
    ``loggerDict``.
    """
    try:
        yield
    finally:
        loggers = [
            logging.getLogger(),
            *logging.Logger.manager.loggerDict.values(),
        ]
        for logger in loggers:
            # Some loggerDict entries have no ``handlers`` attribute, hence
            # the getattr default. Iterate over a snapshot: removeHandler()
            # mutates ``logger.handlers``, and removing from the live list
            # while looping over it skips the handler that follows each
            # removed one.
            for handler in list(getattr(logger, "handlers", ())):
                if isinstance(handler, logging.StreamHandler):
                    logger.removeHandler(handler)
120+
121+
122+
@pytest.fixture(autouse=True, scope="module")
def torch_sanity_check():
    """Module-scoped autouse fixture that fails fast when CUDA is unavailable.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` returns False.
    """
    gpu_present = torch.cuda.is_available()
    if gpu_present:
        return
    raise RuntimeError("GPU is required to run example tests!")
126+
127+
128+
@pytest.fixture(autouse=True)
def torch_empty_cache():
    """
    Autouse fixture that empties the torch CUDA cache once each test has
    finished, to reduce the risk of OOM errors.
    """
    yield
    # Teardown: nothing to release when CUDA is not available.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
136+
137+
138+
@pytest.fixture(autouse=True)
def torch_seed(request):
    """Seed torch's RNG from the test's node id before each test.

    The seed is a CRC32 of ``request.node.nodeid``, so a given test gets the
    same seed in every run and in every worker process. Seeding is skipped
    when CUDA is not available.
    """
    if torch.cuda.is_available():
        import zlib  # local import keeps this fixture self-contained

        # hash() of a str is salted per interpreter process (PYTHONHASHSEED),
        # which made the seed differ from run to run. CRC32 is stable and
        # already lies in [0, 2**32).
        seed = zlib.crc32(request.node.nodeid.encode("utf-8"))
        torch.manual_seed(seed)
143+
144+
145+
@pytest.fixture(autouse=True)
def numpy_seed(request):
    """Seed NumPy's global RNG from the test's node id before each test.

    The seed is a CRC32 of ``request.node.nodeid``, so a given test gets the
    same seed in every run and in every worker process.
    """
    import zlib  # local import keeps this fixture self-contained

    # hash() of a str is salted per interpreter process (PYTHONHASHSEED),
    # which made the seed differ from run to run. CRC32 is stable and already
    # lies in [0, 2**32), the range np.random.seed accepts.
    seed = zlib.crc32(request.node.nodeid.encode("utf-8"))
    np.random.seed(seed)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
7+
# 1. Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
10+
# 2. Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
14+
# 3. Neither the name of the copyright holder nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
def pytest_configure(config):
    """Record "100f" as this file's entry in ``config.default_SMs``.

    NOTE(review): ``default_SMs`` is not created here; presumably a plugin
    attaches it to ``config`` beforehand -- confirm.
    """
    sm_by_file = config.default_SMs
    sm_by_file[__file__] = "100f"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
7+
# 1. Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
10+
# 2. Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
14+
# 3. Neither the name of the copyright holder nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
from blackwell.tutorial_gemm import fp16_gemm_0
30+
from blackwell.tutorial_gemm import fp16_gemm_1
31+
32+
import pytest
33+
from typing import Tuple
34+
35+
36+
@pytest.mark.parametrize("mnk", [(512, 512, 256)])
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_0(mnk: Tuple[int, int, int], tolerance: float) -> None:
    """Run ``fp16_gemm_0.run_dense_gemm`` with the parametrized ``mnk`` and ``tolerance``."""
    fp16_gemm_0.run_dense_gemm(mnk, tolerance)
46+
47+
48+
@pytest.mark.parametrize("mnk", [(512, 512, 256)])
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_1(mnk: Tuple[int, int, int], tolerance: float) -> None:
    """Run ``fp16_gemm_1.run_dense_gemm`` with the parametrized ``mnk`` and ``tolerance``."""
    fp16_gemm_1.run_dense_gemm(mnk, tolerance)

test/utils/device_info.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
def _get_device_compute_capability():
2+
try:
3+
import cuda.bindings.driver as drv
4+
from cuda.bindings.driver import CUdevice_attribute as dev_attr
5+
6+
def drv_api(api_name, *args):
7+
ret_code, *result = getattr(drv, api_name)(*args)
8+
if ret_code:
9+
raise ValueError(f"CUDA error: {ret_code}")
10+
return result[0] if len(result) == 1 else result
11+
12+
drv_api("cuInit", 0)
13+
device = drv_api("cuDeviceGet", 0)
14+
major = drv_api(
15+
"cuDeviceGetAttribute",
16+
dev_attr.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
17+
device,
18+
)
19+
minor = drv_api(
20+
"cuDeviceGetAttribute",
21+
dev_attr.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
22+
device,
23+
)
24+
return f"{major}{minor}"
25+
except Exception as e:
26+
print(f"Failed to get CUDA compute capability: {e}")
27+
return None
28+
29+
30+
# Queried once when this module is imported; None if the capability could not
# be determined.
compute_capability = _get_device_compute_capability()

0 commit comments

Comments
 (0)