Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
The compiler takes an unusually long time to compile this small kernel.
Steps/Code to reproduce bug
Run this snippet:
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.torch import dtype as cute_dtype_to_torch_dtype

cute_ab_dtype = cute.Float16
cute_accumulation_dtype = cute.Float16
ab_dtype = cute_dtype_to_torch_dtype(cute_ab_dtype)
acc_dtype = cute_dtype_to_torch_dtype(cute_accumulation_dtype)

MMA_TILER = MMA_M, MMA_N, MMA_K = (128, 256, 16)

@cute.kernel()
def kernel(
    mA,
    mB,
    tiled_copy,
    StorageType: cutlass.Constexpr,
    block_sA_layout: cute.ComposedLayout,
    block_sB_layout: cute.ComposedLayout,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(StorageType)

    tile_coords = (bidx, bidy, None)
    gA = cute.local_tile(mA, MMA_TILER, tile_coords, proj=(1, None, 1))  # (MmaTile_M, MmaTile_K, Tiles_K)
    gB = cute.local_tile(mB, MMA_TILER, tile_coords, proj=(None, 1, 1))  # (MmaTile_N, MmaTile_K, Tiles_K)

    # Swizzled (MmaA, NumMma_M, NumMma_K); we take the first tile because pipeline=1.
    tCsA = storage.sA.get_tensor(block_sA_layout.outer, swizzle=block_sA_layout.inner)[None, None, None, 0]
    # Swizzled (MmaB, NumMma_N, NumMma_K); we take the first tile because pipeline=1.
    tCsB = storage.sB.get_tensor(block_sB_layout.outer, swizzle=block_sB_layout.inner)[None, None, None, 0]

    thr_tiled_copy = tiled_copy.get_slice(tidx)
    thr_sA = thr_tiled_copy.partition_D(cute.flatten(tCsA[None, 0, 0]))
    thr_sB = thr_tiled_copy.partition_D(cute.flatten(tCsB[None, 0, 0]))

    num_k_blocks = gA.shape[2]
    for i in cutlass.range(num_k_blocks):
        thr_gA = thr_tiled_copy.partition_S(gA[None, None, i])
        cute.copy(thr_tiled_copy, thr_gA, thr_sA)
        thr_gB = thr_tiled_copy.partition_S(gB[None, None, i])
        cute.copy(thr_tiled_copy, thr_gB, thr_sB)

@cute.jit
def launcher(a, b):
    tiled_copy = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cute_ab_dtype),
        cute.make_layout((4 * cute.arch.WARP_SIZE,), stride=(MMA_N,)),
        cute.make_layout((MMA_N,), stride=(1,)),
    )
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        cute_ab_dtype,
        cute.nvgpu.tcgen05.OperandMajorMode.K,
        cute.nvgpu.tcgen05.OperandMajorMode.K,
        cute_accumulation_dtype,
        cute.nvgpu.tcgen05.CtaGroup.ONE,
        MMA_TILER[:2],
    )
    block_sA_layout = sm100_utils.make_smem_layout_a(
        tiled_mma,
        MMA_TILER,
        cute_ab_dtype,
        1,
    )
    block_sB_layout = sm100_utils.make_smem_layout_b(
        tiled_mma,
        MMA_TILER,
        cute_ab_dtype,
        1,
    )

    @cute.struct
    class SharedStorage:
        tmem_ptr_buffer: cute.Int32  # tmem ptrs are 32 bits
        mma_barrier: cute.Int64
        sA: cute.struct.Align[
            cute.struct.MemRange[cute_ab_dtype, cute.cosize(block_sA_layout)],
            16,
        ]
        sB: cute.struct.Align[
            cute.struct.MemRange[cute_ab_dtype, cute.cosize(block_sB_layout)],
            16,
        ]

    kernel(
        a, b, tiled_copy, SharedStorage, block_sA_layout, block_sB_layout
    ).launch(
        grid=(1, 1, 1),
        # You need a warpgroup for the epilogue because each warp can only access 32 lanes of TMEM.
        block=(4 * cute.arch.WARP_SIZE, 1, 1),
    )

if __name__ == "__main__":
    import torch

    M, N, K = 128, 256, 16
    a = torch.randn(M, K, device="cuda", dtype=ab_dtype)
    b = torch.randn(N, K, device="cuda", dtype=ab_dtype)
    launcher(a, b)
    torch.cuda.synchronize()
Expected behavior
I'm used to big kernels taking seconds to compile; this small one clearly triggers something pathological in the compiler. I would expect it to compile in a few seconds at most.
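To isolate compilation from the launch itself, here is a minimal timing sketch appended after the repro snippet above. It assumes cute.compile performs the full ahead-of-time compilation without launching, as in the CuTe DSL examples; the exact API may differ by version.

# Timing sketch (reuses launcher and ab_dtype from the repro above).
# Assumption: cute.compile runs the full JIT pipeline but does not launch.
import time
import torch

M, N, K = 128, 256, 16
a = torch.randn(M, K, device="cuda", dtype=ab_dtype)
b = torch.randn(N, K, device="cuda", dtype=ab_dtype)

t0 = time.perf_counter()
compiled_launcher = cute.compile(launcher, a, b)  # compile only, no launch
print(f"compile time: {time.perf_counter() - t0:.1f}s")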
Environment details (please complete the following information):
GB200, torch==2.8.0+cu129, nvidia-cutlass-dsl==4.2.1