1111from cuda .core .experimental import Device , ObjectCode , Program , ProgramOptions , system
1212from cuda .core .experimental ._utils .cuda_utils import CUDAError , driver , get_binding_version , handle_return
1313
14+ try :
15+ import numba
16+ except ImportError :
17+ numba = None
18+
1419SAXPY_KERNEL = r"""
1520template<typename T>
1621__global__ void saxpy(const T a,
@@ -269,9 +274,10 @@ def test_saxpy_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, b
269274 assert kernel .attributes .num_regs () * num_blocks_per_sm <= dev_props .max_registers_per_multiprocessor
270275
271276
272- @pytest .mark .parametrize ("block_size_limit" , [32 , 64 , 96 , 120 , 128 , 256 ])
277+ @pytest .mark .parametrize ("block_size_limit" , [32 , 64 , 96 , 120 , 128 , 256 , 0 ])
273278@pytest .mark .parametrize ("smem_size_per_block" , [0 , 32 , 4096 ])
274- def test_saxpy_occupancy_max_potential_block_size (get_saxpy_kernel , block_size_limit , smem_size_per_block ):
279+ def test_saxpy_occupancy_max_potential_block_size_constant (get_saxpy_kernel , block_size_limit , smem_size_per_block ):
280+ """Tests use case when shared memory needed is independent on the block size"""
275281 kernel , _ = get_saxpy_kernel
276282 dev_props = Device ().properties
277283 assert block_size_limit <= dev_props .max_threads_per_block
@@ -284,7 +290,40 @@ def test_saxpy_occupancy_max_potential_block_size(get_saxpy_kernel, block_size_l
284290 assert isinstance (max_block_size , int )
285291 assert min_grid_size > 0
286292 assert max_block_size > 0
287- assert max_block_size <= block_size_limit
293+ if block_size_limit > 0 :
294+ assert max_block_size <= block_size_limit
295+ else :
296+ assert max_block_size <= dev_props .max_threads_per_block
297+ assert min_grid_size == config_data .min_grid_size
298+ assert max_block_size == config_data .max_block_size
299+
300+
@pytest.mark.skipif(numba is None, reason="Test requires numba to be installed")
@pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 277, 0])
def test_saxpy_occupancy_max_potential_block_size_b2dsize(get_saxpy_kernel, block_size_limit):
    """Tests use case when shared memory needed depends on the block size.

    The dynamic shared memory requirement is supplied to the occupancy query
    as a C callback (compiled with numba) mapping a block size to a size,
    instead of as a constant.

    Args:
        get_saxpy_kernel: Fixture yielding a tuple whose first element is the
            kernel under test.
        block_size_limit: Upper bound passed to the occupancy query. A value
            of 0 is handled below as "no explicit limit": the result is then
            only checked against the device's max threads per block.
    """
    kernel, _ = get_saxpy_kernel
    dev_props = Device().properties
    # Validate the parametrized limit *before* it is handed to the occupancy
    # query, so a bad parameter fails here rather than inside the driver call.
    assert block_size_limit <= dev_props.max_threads_per_block

    def shared_memory_needed(block_size: numba.intc) -> numba.size_t:
        """Size of dynamic shared memory needed by kernel of this block size."""
        # 1024 for every full group of 32 threads in the block.
        return 1024 * (block_size // 32)

    # Compile the Python callback into a C-callable function with signature
    # size_t(int) and wrap its raw address in the driver's callback type.
    b2dsize_sig = numba.size_t(numba.intc)
    dsmem_needed_cfunc = numba.cfunc(b2dsize_sig)(shared_memory_needed)
    # fn_ptr is only an integer address; dsmem_needed_cfunc stays referenced
    # as a local for the whole duration of the query below.
    fn_ptr = ctypes.cast(dsmem_needed_cfunc.ctypes, ctypes.c_void_p).value
    b2dsize_fn = driver.CUoccupancyB2DSize(_ptr=fn_ptr)

    config_data = kernel.occupancy.max_potential_block_size(b2dsize_fn, block_size_limit)
    min_grid_size, max_block_size = config_data

    assert isinstance(min_grid_size, int)
    assert isinstance(max_block_size, int)
    assert min_grid_size > 0
    assert max_block_size > 0
    if block_size_limit > 0:
        assert max_block_size <= block_size_limit
    else:
        # No explicit limit requested: the device maximum is the only bound.
        assert max_block_size <= dev_props.max_threads_per_block
288327
289328
290329@pytest .mark .parametrize ("num_blocks_per_sm, block_size" , [(4 , 32 ), (2 , 64 ), (2 , 96 ), (3 , 120 ), (2 , 128 ), (1 , 256 )])
0 commit comments