
Commit 4ad4042

Merge pull request #261 from leofang/cluster
Add `cluster` to `LaunchConfig` to support thread block clusters on Hopper
2 parents f1267cd + 4b95ba4
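In short, `LaunchConfig` gains a `cluster` attribute, honored when the extended launch API (`cuLaunchKernelEx`, driver 11.8+) is available and the device is Hopper (compute capability 9.0) or newer. A minimal usage sketch, condensed from the example added in this PR (`ker` is assumed to be a compiled `Kernel` object, as produced further below):

    from cuda.core.experimental import Device, LaunchConfig, launch

    dev = Device()
    dev.set_current()

    # 4 blocks in the grid, grouped into clusters of 2 blocks, 32 threads per block
    config = LaunchConfig(grid=4, cluster=2, block=32, stream=dev.default_stream)
    launch(ker, config)
    dev.sync()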

File tree

4 files changed: +107 -5 lines changed

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 25 additions & 5 deletions
@@ -7,6 +7,7 @@
 from typing import Optional, Union

 from cuda import cuda
+from cuda.core.experimental._device import Device
 from cuda.core.experimental._kernel_arg_handler import ParamHolder
 from cuda.core.experimental._module import Kernel
 from cuda.core.experimental._stream import Stream
@@ -38,10 +39,14 @@ class LaunchConfig:
     ----------
     grid : Union[tuple, int]
         Collection of threads that will execute a kernel function.
+    cluster : Union[tuple, int]
+        Group of blocks (Thread Block Cluster) that will execute on the same
+        GPU Processing Cluster (GPC). Blocks within a cluster have access to
+        distributed shared memory and can be explicitly synchronized.
     block : Union[tuple, int]
         Group of threads (Thread Block) that will execute on the same
-        multiprocessor. Threads within a thread blocks have access to
-        shared memory and can be explicitly synchronized.
+        streaming multiprocessor (SM). Threads within a thread block have
+        access to shared memory and can be explicitly synchronized.
     stream : :obj:`Stream`
         The stream establishing the stream ordering semantic of a
         launch.
@@ -53,13 +58,22 @@ class LaunchConfig:

     # TODO: expand LaunchConfig to include other attributes
     grid: Union[tuple, int] = None
+    cluster: Union[tuple, int] = None
     block: Union[tuple, int] = None
     stream: Stream = None
     shmem_size: Optional[int] = None

     def __post_init__(self):
+        _lazy_init()
         self.grid = self._cast_to_3_tuple(self.grid)
         self.block = self._cast_to_3_tuple(self.block)
+        # thread block clusters are supported starting H100
+        if self.cluster is not None:
+            if not _use_ex:
+                raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+")
+            if Device().compute_capability < (9, 0):
+                raise CUDAError("thread block clusters are not supported on devices with compute capability < 9.0")
+            self.cluster = self._cast_to_3_tuple(self.cluster)
         # we handle "stream=None" in the launch API
         if self.stream is not None and not isinstance(self.stream, Stream):
             try:
@@ -69,8 +83,6 @@ def __post_init__(self):
         if self.shmem_size is None:
             self.shmem_size = 0

-        _lazy_init()
-
     def _cast_to_3_tuple(self, cfg):
         if isinstance(cfg, int):
             if cfg < 1:
@@ -133,7 +145,15 @@ def launch(kernel, config, *kernel_args):
         drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
         drv_cfg.hStream = config.stream.handle
         drv_cfg.sharedMemBytes = config.shmem_size
-        drv_cfg.numAttrs = 0 # TODO
+        attrs = [] # TODO: support more attributes
+        if config.cluster:
+            attr = cuda.CUlaunchAttribute()
+            attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
+            dim = attr.value.clusterDim
+            dim.x, dim.y, dim.z = config.cluster
+            attrs.append(attr)
+        drv_cfg.numAttrs = len(attrs)
+        drv_cfg.attrs = attrs
         handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
     else:
         # TODO: check if config has any unsupported attrs
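The `__post_init__` changes above normalize `cluster` the same way `grid` and `block` are normalized, after gating on both the bindings/driver (`_use_ex`) and the device's compute capability; at launch time, a non-None `cluster` becomes a single `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION` attribute on the `cuLaunchKernelEx` config. A small sketch of the resulting behavior (assumes a Hopper-class device is current):

    from cuda.core.experimental import Device, LaunchConfig

    dev = Device()
    dev.set_current()

    # ints are cast to 3-tuples by _cast_to_3_tuple, so cluster=2 becomes (2, 1, 1)
    config = LaunchConfig(grid=4, cluster=2, block=32, stream=dev.default_stream)
    assert config.cluster == (2, 1, 1)

    # On a pre-9.0 device, or without cuda.bindings & driver 11.8+, the same
    # construction raises CUDAError instead, per the checks in __post_init__.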

cuda_core/docs/source/release/0.1.1-notes.md

Lines changed: 14 additions & 0 deletions
@@ -12,6 +12,20 @@ Released on Dec XX, 2024
 - Add a `cuda.core.experimental.system` module for querying system- or process-wide information.
 - Support TCC devices with a default synchronous memory resource to avoid the use of memory pools

+## New features
+
+- Add `LaunchConfig.cluster` to support thread block clusters on Hopper GPUs.
+
+## Enhancements
+
+- Ensure "ltoir" is a valid code type to `ObjectCode`.
+- Improve test coverage.
+- Enforce code formatting.
+
+## Bug fixes
+
+- Eliminate potential class destruction issues.
+- Fix circular import when handling a foreign CUDA stream.

 ## Limitations


cuda_core/examples/thread_block_cluster.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
+
+import os
+import sys
+
+from cuda.core.experimental import Device, LaunchConfig, Program, launch
+
+# prepare include
+cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
+if cuda_path is None:
+    print("this demo requires a valid CUDA_PATH environment variable set", file=sys.stderr)
+    sys.exit(0)
+cuda_include_path = os.path.join(cuda_path, "include")
+
+# print cluster info using a kernel
+code = r"""
+#include <cooperative_groups.h>
+
+namespace cg = cooperative_groups;
+
+extern "C"
+__global__ void check_cluster_info() {
+    auto g = cg::this_grid();
+    auto b = cg::this_thread_block();
+    if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) {
+        printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z);
+        printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z);
+        printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z);
+    }
+}
+"""
+
+dev = Device()
+arch = dev.compute_capability
+if arch < (9, 0):
+    print(
+        "this demo requires compute capability >= 9.0 (since thread block cluster is a hardware feature)",
+        file=sys.stderr,
+    )
+    sys.exit(0)
+arch = "".join(f"{i}" for i in arch)
+
+# prepare program & compile kernel
+dev.set_current()
+prog = Program(code, code_type="c++")
+mod = prog.compile(
+    target_type="cubin",
+    # TODO: update this after NVIDIA/cuda-python#237 is merged
+    options=(f"-arch=sm_{arch}", "-std=c++17", f"-I{cuda_include_path}"),
+)
+ker = mod.get_kernel("check_cluster_info")
+
+# prepare launch config
+grid = 4
+cluster = 2
+block = 32
+config = LaunchConfig(grid=grid, cluster=cluster, block=block, stream=dev.default_stream)
+
+# launch kernel on the default stream
+launch(ker, config)
+dev.sync()
+
+print("done!")

cuda_core/tests/example_tests/utils.py

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,9 @@ def run_example(samples_path, filename, env=None):
                 break
             else:
                 raise
+    except SystemExit:
+        # for samples that early return due to any missing requirements
+        pytest.skip(f"skip {filename}")
     except Exception as e:
         msg = "\n"
         msg += f"Got error ({filename}):\n"
