Skip to content

Commit ff322ec

Browse files
Add tests for cluster-related occupancy descriptors
1 parent 9679e0e commit ff322ec

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

cuda_core/tests/test_module.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,32 @@ def test_saxpy_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_ker
296296
smem_size = kernel.occupancy.available_dynamic_shared_memory_per_block(num_blocks_per_sm, block_size)
297297
assert smem_size <= dev_props.max_shared_memory_per_block
298298
assert num_blocks_per_sm * smem_size <= dev_props.max_shared_memory_per_multiprocessor
299+
300+
301+
@pytest.mark.parametrize("cluster", [None, 2])
def test_saxpy_occupancy_max_active_clusters(get_saxpy_kernel, cluster):
    """Check ``kernel.occupancy.max_active_clusters`` returns a non-negative int.

    The query is issued twice for the same launch configuration: once with
    no stream argument and once with the device's default stream passed
    explicitly. The case is skipped when a cluster size is requested and the
    device reports a compute capability below (9, 0).
    """

    def _expect_non_negative_int(value):
        # Both query variants must satisfy the same two conditions.
        assert isinstance(value, int)
        assert value >= 0

    saxpy_kernel, _ = get_saxpy_kernel
    device = Device()
    if cluster and device.compute_capability < (9, 0):
        pytest.skip("Device with compute capability 90 or higher is required for cluster support")

    config = cuda.core.experimental.LaunchConfig(grid=128, block=64, cluster=cluster)
    query_max_active_clusters = saxpy_kernel.occupancy.max_active_clusters

    # Query without naming a stream.
    _expect_non_negative_int(query_max_active_clusters(config))

    # Query again, this time passing the device's default stream.
    _expect_non_negative_int(query_max_active_clusters(config, stream=device.default_stream))
315+
316+
317+
def test_saxpy_occupancy_max_potential_cluster_size(get_saxpy_kernel):
    """Check ``kernel.occupancy.max_potential_cluster_size`` returns a non-negative int.

    The query is issued twice for the same launch configuration: once with
    no stream argument and once with the device's default stream passed
    explicitly.
    """

    def _expect_non_negative_int(value):
        # Both query variants must satisfy the same two conditions.
        assert isinstance(value, int)
        assert value >= 0

    saxpy_kernel, _ = get_saxpy_kernel
    device = Device()
    config = cuda.core.experimental.LaunchConfig(grid=128, block=64)
    query_max_cluster_size = saxpy_kernel.occupancy.max_potential_cluster_size

    # Query without naming a stream.
    _expect_non_negative_int(query_max_cluster_size(config))

    # Query again, this time passing the device's default stream.
    _expect_non_negative_int(query_max_cluster_size(config, stream=device.default_stream))

0 commit comments

Comments
 (0)