@@ -211,16 +211,25 @@ def max_active_blocks_per_multiprocessor(self, block_size: int, dynamic_shared_m
211211 )
212212
213213 def max_potential_block_size (
214- self , dynamic_shared_memory_size : int , block_size_limit : int
214+ self , dynamic_shared_memory_needed : Union [ int , driver . CUoccupancyB2DSize ] , block_size_limit : int
215215 ) -> MaxPotentialBlockSizeOccupancyResult :
216216 """MaxPotentialBlockSizeOccupancyResult: Suggested launch configuration for reasonable occupancy.
217217
218218 Returns the minimum grid size needed to achieve the maximum occupancy and
219219 the maximum block size that can achieve the maximum occupancy.
220220 """
221- min_grid_size , max_block_size = handle_return (
222- driver .cuOccupancyMaxPotentialBlockSize (self ._handle , None , dynamic_shared_memory_size , block_size_limit )
223- )
221+ if isinstance (dynamic_shared_memory_needed , int ):
222+ min_grid_size , max_block_size = handle_return (
223+ driver .cuOccupancyMaxPotentialBlockSize (
224+ self ._handle , None , dynamic_shared_memory_needed , block_size_limit
225+ )
226+ )
227+ else :
228+ min_grid_size , max_block_size = handle_return (
229+ driver .cuOccupancyMaxPotentialBlockSize (
230+ self ._handle , dynamic_shared_memory_needed .getPtr (), 0 , block_size_limit
231+ )
232+ )
224233 return MaxPotentialBlockSizeOccupancyResult (min_grid_size = min_grid_size , max_block_size = max_block_size )
225234
226235 def available_dynamic_shared_memory_per_block (self , num_blocks_per_multiprocessor : int , block_size : int ) -> int :
0 commit comments