import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.runtime import from_dlpack
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
@@ -620,7 +619,7 @@ class SharedStorage:
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=(*self.cluster_shape_mn, 1),
-            smem=self.shared_storage.size_in_bytes(),
+            smem=self.shared_storage.size_in_bytes(),  # type: ignore[attr-defined]
            stream=stream,
            min_blocks_per_mp=1,
        )
@@ -1095,7 +1094,7 @@ def kernel(
            #
            # Tma load loop
            #
-            for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+            for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
                tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
                tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
                tAsA_pipe = tAsA[(None, ab_producer_state.index)]
@@ -1187,7 +1186,6 @@ def kernel(
            is_valid_tile = tile_info[3] == 1

            while is_valid_tile:
-
                #
                # Prepare the mask for scaleA/scaleB
                #
@@ -1219,7 +1217,7 @@ def kernel(
                #
                # load loop
                #
-                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
                    #
                    # Slice to per mma tile index
                    #
@@ -1390,7 +1388,6 @@ def kernel(
            is_valid_tile = tile_info[3] == 1

            while is_valid_tile:
-
                # Peek (try_wait) AB buffer full for k_tile = 0
                ab_consumer_state.reset_count()
                peek_ab_full_status = cutlass.Boolean(1)
@@ -1410,7 +1407,7 @@ def kernel(
                #
                # Mma mainloop
                #
-                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
                    # Set tensor memory buffer for current tile
                    # (MMA, MMA_M, MMA_N)
                    tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
@@ -1591,7 +1588,6 @@ def kernel(
            is_valid_tile = tile_info[3] == 1

            while is_valid_tile:
-
                # initialize the final accumulator
                tTR_rAcc_final.fill(0.0)

@@ -1618,7 +1614,7 @@ def kernel(
                    acc_consumer_state
                )

-                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
                    # Set tensor memory buffer for current tile
                    # (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
                    tTR_tAcc = tTR_tAcc_base[
@@ -1794,7 +1790,6 @@ def kernel(

            tTR_rC = None
            tiled_copy_r2s = None
-            simt_atom = None
            tRS_rC = None
            tRS_sC = None
            bSG_sC = None
@@ -2301,7 +2296,7 @@ def _compute_stages(
        sfb_count: int,
        num_smem_capacity: int,
        occupancy: int,
-    ) -> Tuple[int, int, int]:
+    ) -> Tuple[int, int, int, int, int]:
        """Computes the number of stages for A/B/C operands based on heuristics.

        :param tiled_mma: The tiled MMA object defining the core computation.
@@ -2687,7 +2682,7 @@ def __init__(
        self._use_2cta_instrs = use_2cta_instrs
        self._mma_tiler_mn = mma_tiler_mn
        self._cluster_shape_mn = cluster_shape_mn
-
+
        if not BlockwiseGemmKernel.can_implement(
            ab_dtype,
            acc_dtype,
@@ -2709,13 +2704,13 @@ def __init__(

        hardware_info = cutlass.utils.HardwareInfo()
        self._max_active_clusters = min(
-            hardware_info.get_max_active_clusters(
-                self._cluster_shape_mn[0] * self._cluster_shape_mn[1]
-            ),
-            sm_count,
+            hardware_info.get_max_active_clusters(
+                self._cluster_shape_mn[0] * self._cluster_shape_mn[1]
+            ),
+            sm_count,
        )
        self._sm_version = sm_version
-
+
    @cute.jit
    def __call__(
        self,
@@ -2726,7 +2721,6 @@ def __call__(
        c_ptr: cute.Pointer,
        current_stream: cuda.CUstream,
    ):
-        #TODO(asamani): double check the shapes and layouts
        a_tensor = cute.make_tensor(
            a_ptr,
            layout=cute.make_ordered_layout(
@@ -2752,14 +2746,14 @@ def __call__(
            sfa_ptr,
            layout=cute.make_ordered_layout(
                (self._m, math.ceil(self._k / 128), self._l),
-                order=(0, 1, 2), #if self._a_major == "m" else (1, 0, 2)
+                order=(0, 1, 2),
            ),
        )
        sfb_tensor = cute.make_tensor(
            sfb_ptr,
            layout=cute.make_ordered_layout(
-                (math.ceil(self._n / 128), math.ceil(self._k / 128),self._l),
-                order=(1, 0, 2), #if self._b_major == "n" else (1, 0, 2),
+                (math.ceil(self._n / 128), math.ceil(self._k / 128), self._l),
+                order=(1, 0, 2),
            ),
        )

@@ -2777,7 +2771,7 @@ def __call__(
            self._max_active_clusters,
            current_stream,
        )
-
+

@functools.cache
def get_cute_dsl_compiled_blockwise_gemm_kernel(
@@ -2831,7 +2825,7 @@ def get_cute_pointers(
            sfb_tensor_gpu.data_ptr(),
            c_tensor_gpu.data_ptr(),
        )
-
+
        a_ptr = make_ptr(
            ab_dtype,
            a_data_ptr,
@@ -2863,7 +2857,7 @@ def get_cute_pointers(
            assumed_align=16,
        )
        return [a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr]
-
+
    kernel = cute.compile(
        BlockwiseGemmCuteDSL(
            m=m,
@@ -2887,7 +2881,6 @@ def get_cute_pointers(
        cutlass_torch.current_stream(),
    )

-
    def tensor_api(
        a_tensor_gpu: torch.Tensor,
        b_tensor_gpu: torch.Tensor,
@@ -2907,7 +2900,13 @@ def tensor_api(
        nonlocal kernel
        kernel(
            *get_cute_pointers(
-                [a_tensor_gpu, b_tensor_gpu, sfa_tensor_gpu, sfb_tensor_gpu, c_tensor_gpu]
+                [
+                    a_tensor_gpu,
+                    b_tensor_gpu,
+                    sfa_tensor_gpu,
+                    sfb_tensor_gpu,
+                    c_tensor_gpu,
+                ]
            ),
            current_stream,
        )
@@ -2932,7 +2931,7 @@ def blockwise_gemm(
):
    m, k, l = a_torch.shape
    n, _, _ = b_torch.shape
-
+
    mma_tiler_mn = kwargs.pop("mma_tiler_mn", (128, 128))
    cluster_shape_mn = kwargs.pop("cluster_shape_mn", (1, 1))
    if sm_count is None:
@@ -2943,28 +2942,28 @@ def blockwise_gemm(
    major, minor = get_compute_capability(a_torch.device)
    if major == 11 and minor == 0:
        raise ValueError("SM110 is not supported for cute-dsl backend.")
-
+
    return get_cute_dsl_compiled_blockwise_gemm_kernel(
-        m=m,
-        n=n,
-        k=k,
-        l=l,
-        a_major="k",
-        b_major="k",
-        c_major="n",
-        ab_dtype=get_cutlass_dtype(ab_dtype),
-        sf_dtype=get_cutlass_dtype(sf_dtype),
-        c_dtype=get_cutlass_dtype(c_dtype),
-        acc_dtype=get_cutlass_dtype(acc_dtype),
-        use_2cta_instrs=use_2cta_instrs,
-        mma_tiler_mn=mma_tiler_mn,
-        cluster_shape_mn=cluster_shape_mn,
-        sm_count=sm_count,
-        sm_version=f"sm_{major}{minor}",
-    )(
-        a_tensor_gpu=a_torch,
-        b_tensor_gpu=b_torch,
-        sfa_tensor_gpu=sfa_torch,
-        sfb_tensor_gpu=sfb_torch,
-        c_tensor_gpu=c_torch,
-    )
+        m=m,
+        n=n,
+        k=k,
+        l=l,
+        a_major="k",
+        b_major="k",
+        c_major="n",
+        ab_dtype=get_cutlass_dtype(ab_dtype),
+        sf_dtype=get_cutlass_dtype(sf_dtype),
+        c_dtype=get_cutlass_dtype(c_dtype),
+        acc_dtype=get_cutlass_dtype(acc_dtype),
+        use_2cta_instrs=use_2cta_instrs,
+        mma_tiler_mn=mma_tiler_mn,
+        cluster_shape_mn=cluster_shape_mn,
+        sm_count=sm_count,
+        sm_version=f"sm_{major}{minor}",
+    )(
+        a_tensor_gpu=a_torch,
+        b_tensor_gpu=b_torch,
+        sfa_tensor_gpu=sfa_torch,
+        sfb_tensor_gpu=sfb_torch,
+        c_tensor_gpu=c_torch,
+    )
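For reference, here is a minimal usage sketch of the `blockwise_gemm` entry point in the last hunk. The tensor shapes follow the layouts constructed in `__call__` above (A is (m, k, l) and B is (n, k, l), both K-major; the scale-factor extents come from the `math.ceil(... / 128)` terms), while the concrete problem size, dtype strings, and random data are illustrative assumptions, not taken from this diff:

import torch

# Hypothetical problem size; l is the batch dimension.
m, n, k, l = 1024, 1024, 4096, 1

a = torch.randn(m, k, l, device="cuda").to(torch.float8_e4m3fn)        # (m, k, l), K-major
b = torch.randn(n, k, l, device="cuda").to(torch.float8_e4m3fn)        # (n, k, l), K-major
sfa = torch.rand(m, (k + 127) // 128, l, device="cuda")                 # per-block scales for A
sfb = torch.rand((n + 127) // 128, (k + 127) // 128, l, device="cuda")  # per-block scales for B
c = torch.empty(m, n, l, device="cuda", dtype=torch.bfloat16)           # output buffer

blockwise_gemm(
    a_torch=a,
    b_torch=b,
    sfa_torch=sfa,
    sfb_torch=sfb,
    c_torch=c,
    ab_dtype="float8_e4m3fn",  # dtype-string spellings here are assumptions
    sf_dtype="float32",
    c_dtype="bfloat16",
    acc_dtype="float32",
    use_2cta_instrs=False,
    mma_tiler_mn=(128, 128),   # forwarded via **kwargs; defaults shown in the hunk above
    cluster_shape_mn=(1, 1),
)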