Skip to content

[BUG][CuTeDSL] Compiler takes 2 minutes to compile the following ~30-line kernel #2677

@sfc-gh-lpaille

Description

@sfc-gh-lpaille

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
The compiler takes an unusually long time (about 2 minutes) to compile this small kernel.

Steps/Code to reproduce bug
Run this snippet:

import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.torch import dtype as cute_dtype_to_torch_dtype


# Element types: fp16 for the A/B operands and for the accumulator.
cute_ab_dtype = cute.Float16
cute_accumulation_dtype = cute.Float16

# Matching torch dtypes, used when creating host-side tensors.
ab_dtype, acc_dtype = (
    cute_dtype_to_torch_dtype(cute_ab_dtype),
    cute_dtype_to_torch_dtype(cute_accumulation_dtype),
)

# MMA tile sizes along (M, N, K); MMA_TILER bundles them as one tuple.
MMA_M, MMA_N, MMA_K = 128, 256, 16
MMA_TILER = (MMA_M, MMA_N, MMA_K)


@cute.kernel()
def kernel(
    mA, 
    mB, 
    tiled_copy,
    StorageType: cutlass.Constexpr, 
    block_sA_layout: cute.ComposedLayout, 
    block_sB_layout: cute.ComposedLayout
):
    """Copy this CTA's A and B tiles from global into shared memory.

    Reduced reproducer from the issue. It issues no MMA and never reads the
    shared-memory buffers back; only the global -> shared copies are traced.

    Args:
        mA: A operand; tiled with ``proj=(1, None, 1)``, i.e. over the
            (MMA_M, MMA_K) modes of MMA_TILER. Presumably shaped (M, K) --
            TODO confirm against the caller.
        mB: B operand; tiled with ``proj=(None, 1, 1)``, i.e. over the
            (MMA_N, MMA_K) modes. Presumably shaped (N, K) -- TODO confirm.
        tiled_copy: copy used for both the A and the B transfers.
        StorageType: compile-time struct type allocated through
            ``SmemAllocator``; must expose ``sA`` and ``sB`` fields.
        block_sA_layout: composed shared-memory layout for A; ``.outer`` is
            passed as the layout and ``.inner`` as the swizzle.
        block_sB_layout: same as ``block_sA_layout``, for B.
    """
    # Only the x thread index and the (x, y) block indices are used.
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(StorageType)

    # None in the K slot keeps every K tile as a trailing mode of gA / gB.
    tile_coords = (bidx, bidy, None)
    gA = cute.local_tile(mA, MMA_TILER, tile_coords, proj=(1, None, 1)) # (MmaTile_M, MmaTile_K, Tiles_K)
    gB = cute.local_tile(mB, MMA_TILER, tile_coords, proj=(None, 1, 1)) # (MmaTile_N, MmaTile_K, Tiles_K)


    tCsA = storage.sA.get_tensor(block_sA_layout.outer, swizzle=block_sA_layout.inner)[None, None, None, 0] # swizzled (MmaA, NumMma_M, NumMma_K) (we get first tile because pipeline=1)
    tCsB = storage.sB.get_tensor(block_sB_layout.outer, swizzle=block_sB_layout.inner)[None, None, None, 0] # swizzled (MmaB, NumMma_N, NumMma_K) (we get first tile because pipeline=1)

    # Destination partitions are built once, from the [None, 0, 0] slice of
    # each shared tensor; every loop iteration below writes into these same
    # buffers.
    # NOTE(review): tiled_copy (built in launcher) spans 128 threads x 256
    # values = 32768 elements, while the global A tile gA[None, None, i] is
    # 128 x 16 and the B tile is 256 x 16. Confirm this mismatch is
    # intended; it may be related to the reported compile time.
    thr_tiled_copy = tiled_copy.get_slice(tidx)
    thr_sA = thr_tiled_copy.partition_D(cute.flatten(tCsA[None, 0, 0]))
    thr_sB = thr_tiled_copy.partition_D(cute.flatten(tCsB[None, 0, 0]))

    # Tiles_K; equals 1 for the __main__ problem (K == MMA_K == 16).
    num_k_blocks = gA.shape[2]
    for i in cutlass.range(num_k_blocks):
        # No synchronization around these copies; sA / sB are never read here.
        thr_gA = thr_tiled_copy.partition_S(gA[None, None, i])
        cute.copy(thr_tiled_copy, thr_gA, thr_sA)
        thr_gB = thr_tiled_copy.partition_S(gB[None, None, i])
        cute.copy(thr_tiled_copy, thr_gB, thr_sB)



@cute.jit
def launcher(a, b):
    """Build the copy, smem layouts and shared storage, then launch ``kernel``.

    Args:
        a: A operand, forwarded to ``kernel`` as ``mA`` (a torch tensor of
            shape (M, K) in ``__main__``).
        b: B operand, forwarded as ``mB`` ((N, K) in ``__main__``).

    Launches ``kernel`` on a single CTA (grid (1, 1, 1)) of
    4 * WARP_SIZE threads.
    """
    # Copy atom: CopyUniversalOp on cute_ab_dtype elements.
    # Thread layout: 4 * WARP_SIZE threads with stride MMA_N.
    # Value layout: MMA_N contiguous values per thread.
    # NOTE(review): with 32-thread warps that is 128 x 256 = 32768 elements
    # per copy tile, far larger than the 128 x 16 (A) / 256 x 16 (B) global
    # tiles it partitions in `kernel`. Confirm this is intended -- possibly
    # related to the reported compile time.
    tiled_copy = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cute_ab_dtype),
        cute.make_layout((4*cute.arch.WARP_SIZE,), stride=(MMA_N,)),
        cute.make_layout((MMA_N,), stride=(1,))
    )
    # tiled_mma is only used to derive the shared-memory layouts below; it is
    # not passed to `kernel`, which issues no MMA.
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        cute_ab_dtype,
        cute.nvgpu.tcgen05.OperandMajorMode.K,
        cute.nvgpu.tcgen05.OperandMajorMode.K,
        cute_accumulation_dtype,
        cute.nvgpu.tcgen05.CtaGroup.ONE,
        MMA_TILER[:2],
    )

    # Trailing argument 1: number of pipeline stages; `kernel` slices stage 0.
    block_sA_layout = sm100_utils.make_smem_layout_a(
        tiled_mma,
        MMA_TILER,
        cute_ab_dtype,
        1,
    )
    block_sB_layout = sm100_utils.make_smem_layout_b(
        tiled_mma,
        MMA_TILER,
        cute_ab_dtype,
        1,
    )

    @cute.struct
    class SharedStorage:
        # NOTE(review): tmem_ptr_buffer and mma_barrier are not referenced by
        # `kernel`; in this reproducer they only add to the struct size.
        tmem_ptr_buffer: cute.Int32 # tmem ptrs are 32 bits
        mma_barrier: cute.Int64
        # sA / sB: cosize(layout) elements of cute_ab_dtype each, wrapped in
        # cute.struct.Align[..., 16].
        sA: cute.struct.Align[
            cute.struct.MemRange[
                cute_ab_dtype, cute.cosize(block_sA_layout)
            ],
            16,
        ]
        sB: cute.struct.Align[
            cute.struct.MemRange[
                cute_ab_dtype, cute.cosize(block_sB_layout)
            ],
            16,
        ]
    kernel(
        a, b, tiled_copy, SharedStorage, block_sA_layout, block_sB_layout
    ).launch(
        grid=(1, 1, 1), 
        # 4 * WARP_SIZE threads (one warpgroup); this matches the thread count
        # of tiled_copy's thread layout. Original note: "you need a WG for the
        # epilogue because each warp can only access 32lanes of TMEM".
        # NOTE(review): this reduced kernel has no TMEM epilogue.
        block=(4*cute.arch.WARP_SIZE, 1, 1)
    )

if __name__ == "__main__":
    import torch

    # The launcher runs a single CTA (grid=(1, 1, 1)), which only processes
    # tile (0, 0), so the problem must be exactly one MMA tile. Derive the
    # sizes from MMA_TILER instead of repeating the literals so the two
    # cannot drift apart.
    M, N, K = MMA_TILER
    a = torch.randn(M, K, device="cuda", dtype=ab_dtype)  # A: (M, K)
    b = torch.randn(N, K, device="cuda", dtype=ab_dtype)  # B: (N, K)
    launcher(a, b)
    # Block until the kernel finishes so launch/runtime errors surface here.
    torch.cuda.synchronize()

Expected behavior
I expect a kernel this small to compile in seconds, as much larger kernels do; this one clearly triggers some unusual behavior in the compiler.

Environment details (please complete the following information):
GB200, torch==2.8.0+cu129, nvidia-cutlass-dsl==4.2.1

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions