From a56227659b0b18bd1a38ea308aad84399ebfcdc0 Mon Sep 17 00:00:00 2001
From: "Lu,Chengjun"
Date: Mon, 21 Jul 2025 15:49:00 +0000
Subject: [PATCH] [LoadStoreOpToLLVM] Refactor block load lowering of tt.load
 with tensor pointer.

Signed-off-by: Lu,Chengjun
---
 python/test/unit/intel/test_block_store.py |  34 +-
 .../tensor-pointer-load-block-2d.mlir       | 245 ++++-
 .../LoadStoreOpToLLVM.cpp                   | 983 +++++++-----------
 3 files changed, 637 insertions(+), 625 deletions(-)

diff --git a/python/test/unit/intel/test_block_store.py b/python/test/unit/intel/test_block_store.py
index f1fbea7bae..6656558a51 100644
--- a/python/test/unit/intel/test_block_store.py
+++ b/python/test/unit/intel/test_block_store.py
@@ -133,20 +133,24 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']

     if block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array}} : >
+            %store_val = tt.load %src_ptr {{boundaryCheck = array, padding = 1 : i32}} : !tt.ptr>
+        """
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
-            %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array}} : >
-            tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array}} : !tt.ptr>
+            %dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array}} : >
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array}} : !tt.ptr>
         """
     else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """

     ir = f"""
@@ -154,6 +158,10 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
         tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
             %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
             %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
             %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
@@ -163,9 +171,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
             %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            {load_ops}

             {store_ops}

@@ -191,3 +197,5 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl

     if support_block_io:
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
+        if not block_ptr:
+            assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
diff --git a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir
index 1991da0df8..74ef28a5c4 100644
--- a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir
+++ b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir
@@ -57,7 +57,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
     %65 = tt.splat %64 : i32 -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %66 = arith.cmpi slt, %38, %65 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %67 = tt.broadcast %66 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<128x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
-    // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2
+    // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default}
     %68 = tt.load %60, %67, %cst_3 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %74 = tt.addptr %60, %cst_0 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %76 = arith.addi %58, %c1_i32 : i32
@@ -69,7 +69,73 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt

 // -----

-#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2]}>
+module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
+  tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 4 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 :
i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %15 = arith.muli %13, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>%31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %48 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %52 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %54 = arith.addi %arg5, %c63_i32 : i32 + %55 = arith.divsi %54, %c64_i32 : i32 + %56 = arith.muli %arg7, %c64_i32 : i32 + %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 
2}>>) + ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>): // 2 preds: ^bb0, ^bb2 + %62 = arith.cmpi slt, %58, %55 : i32 + cf.cond_br %62, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %63 = arith.muli %58, %c64_i32 : i32 + %64 = arith.subi %arg5, %63 : i32 + %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-COUNT-32: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 4, v_blocks = 2, transpose = false, vnni_transform = false, cache_control = Default} + %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %76 = arith.addi %58, %c1_i32 : i32 + cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb3: // pred: ^bb1 + tt.return + } +} + +// ----- + +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 1]}> module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} { %c63_i32 = arith.constant 63 : i32 @@ -123,7 +189,73 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1 + // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = true, cache_control = Default} + %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %76 = arith.addi %58, %c1_i32 : i32 + cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 
1, parent = #mma, kWidth = 2}>>) + ^bb3: // pred: ^bb1 + tt.return + } +} + +// ----- + +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2]}> +module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { + tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %15 = arith.muli %13, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>%31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %48 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> 
tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %52 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %54 = arith.addi %arg5, %c63_i32 : i32 + %55 = arith.divsi %54, %c64_i32 : i32 + %56 = arith.muli %arg7, %c64_i32 : i32 + %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>): // 2 preds: ^bb0, ^bb2 + %62 = arith.cmpi slt, %58, %55 : i32 + cf.cond_br %62, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %63 = arith.muli %58, %c64_i32 : i32 + %64 = arith.subi %arg5, %63 : i32 + %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2, transpose = false, vnni_transform = true, cache_control = Default} %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %76 = arith.addi %58, %c1_i32 : i32 @@ -256,44 +388,105 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[TOP_LEFT_MASK_BOOL_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(i1, i1, {{.*}} // CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}} // CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}} - + // CHECK: %[[VAL_2878:.*]] = llvm.extractvalue {{.*}}[126] : !llvm.struct<(f16, f16, {{.*}} + // CHECK: %[[VAL_2879:.*]] = llvm.extractvalue {{.*}}[127] : !llvm.struct<(f16, f16, {{.*}} // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_2886:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_2886]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_0]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // 
CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK: llvm.select {{.*}}, %[[LOAD_0]], {{.*}} : i1, vector<32xf16> - + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_3046]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 - // CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK: llvm.select {{.*}}, %[[LOAD_1]], {{.*}} : i1, vector<32xf16> - + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector 
%[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_3046]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 - // CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK: llvm.select {{.*}}, %[[LOAD_2]], {{.*}} : i1, vector<32xf16> - + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // 
CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_3046]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_3]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_3:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 - // CHECK: %[[LOAD_3:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_3]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK: llvm.select {{.*}}, %[[LOAD_3]], {{.*}} : i1, vector<32xf16> + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> %11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr, #mma> tt.return diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp 
b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 6b574c35bb..4bcd48d5c5 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -24,6 +24,24 @@ using namespace mlir::triton::gpu::intel; #define S(v) StringAttr::get(ctx, (v)) +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_clz(unsigned x) { + unsigned long r; + _BitScanReverse(&r, x); + return static_cast(r ^ 31); +} + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +#endif + namespace { Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { @@ -1782,215 +1800,110 @@ struct LoadOpToBlockIOConversion LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - if (!isBlockIOCandidate(op)) + static const bool enableBlockIOForAllLayout = + triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS"); + if (!isBlockIOCandidate(op, enableBlockIOForAllLayout)) return failure(); - // 2D block io lowering steps: - // 1. Get the 2 dims for 2D block io: one of the dimension chosen correspond - // to the dimension where the access pattern has stride one. The other - // dimension should be the one with the constancy stride. - // 2. Check the DPAS layout, fallback to gather IO lowering if the block - // IO is not supported for the layout. TODO: to generalize the code for - // different layout with the linear layout. - // 3. Compute the maximum tile size for the 2D block io from the layout - // information. - // 4. Generates the 2D block IO instructions. - // 5. Unpacked the loaded values into expected order required by the layout. - - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto b = TritonLLVMOpBuilder(loc, rewriter); Value ptr = op.getPtr(); if (isTensorPointerType(ptr.getType())) return rewriteTensorPointerLoad(op, adaptor, rewriter); - Value mask = op.getMask(); - Value other = op.getOther(); - Type resultType = op.getType(); - auto tensorType = cast(resultType); - - // Step 1: Right now we only support 2D rank matrix of row major or column - // major. - const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); - + Value value = op.getResult(); + auto tensorType = cast(value.getType()); + // Get the max tile shape supported by the layout. Attribute encoding = tensorType.getEncoding(); std::optional llEncoding = cast(encoding).toLinearLayout( tensorType.getShape()); - assert(llEncoding.has_value() && "invalid dot layout to linear layout"); - auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); - SmallVector threadOrder(llAttr.getThreadOrder()); - size_t rank = threadOrder.size(); - if (rank != 2) { - // only support rank of 2 for now. - return failure(); - } - const bool valueRowMajor = - (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); - assert((valueRowMajor || - (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - - // Step 2: Right now we only support DPAS related layout to simplify the - // lowering. 
- Type eltTy = tensorType.getElementType(); - unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); - const ArrayRef tensorShape = tensorType.getShape(); - unsigned numElems = getTotalElemsPerThread(resultType); - SmallVector repetitons = - dpasLayout.getDPASRepetitions(tensorShape, opIdx); - assert(repetitons.size() == 3 && - "getDPASRepetitions always return rank 3 size"); - assert(repetitons[0] == 1 && "Only supports rank of 2 for now"); - SmallVector numReps{repetitons[1], repetitons[2]}; - ArrayRef warpsPerCTA = dpasLayout.getWarpsPerCTA(); - SmallVector dpasWarpsOrder = - getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); - unsigned threadsPerWarp = - product(getThreadsPerWarp(dpasLayout, tensorShape)); - - Value warpId = rewriter.create( - loc, i32_ty, - rewriter.create(loc, /*upperBound=*/nullptr)); - - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); + assert(llEncoding.has_value() && + "unexpected failure when getting linear layout"); - // By default, use the unpacked type for the 2D load result type. - Type loadResultElemType = typeConverter->convertType(eltTy); - bool usePackedType = false; - unsigned packedElemsNum = 1; - // The tensor values are distributed as DotOp layout of DPAS. - // If the element size of the tensor matches the DPAS packed layout, then - // use the packed type for the 2D load result type. For example, - // The intermediate ops generated by ConvertTritonGPUToLLVM: - // %0 = load_2d %ptr : vector<8 x i32> - // %1 = bitcast %0 : vector<8 x i32> -> vector<16 x f16> - // %2 = bitcast %1 : vector<16 x f16> -> vector<8 x i32> - // %3 = dpas %2 - // And the LLVM dialect optimization pass can eliminate the duplicated - // bitcast. Then there is a shortcut to use the load result directly as the - // input operands to DPAS. - // TODO: add support for int4 and int2. + auto [tileHeight, tileWidth, numPackedVals, vBlocks, rowDim, colDim, + regPackedBases] = + getBlockIOTileSize<32 /*MAX_TILE_HEIGHT*/>(*llEncoding); - // OperandA: outer dim -> M, inner dim -> K. - // OperandB: outer dim -> N, inner dim -> K. - // OperandC: outer dim -> M, inner dim -> N. - // Round the warp id fit into the tensor shape. - unsigned dimOuter; - unsigned dimInner; - SmallVector repCluster(dpasLayout.getRepCluster()); - SmallVector warpShape; - SmallVector dpasInstShape; + // no valid tile shape for 2D block IO. + if (colDim < 0) + return failure(); - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: { - warpShape = std::move(dpasLayout.getShapeA()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeA()); - dimOuter = rank - 2; - dimInner = rank - 1; - repCluster[dimInner] = 1; - - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = elemSizeInBits == 32 ? i32_ty : i16_ty; - packedElemsNum = opsPerChannel == 4 ? 2 : 1; - usePackedType = true; - } else if (opsPerChannel == 4) { - packedElemsNum = 2; - unsigned packedBitWidht = elemSizeInBits * packedElemsNum; - if (packedBitWidht > 64) { - // Be conservative to avoid the packed type exceeds 64 bits. - return failure(); - } - // Need to pack two column into one to work around vectorization - // limitation. 
- loadResultElemType = int_ty(packedBitWidht); - usePackedType = true; - } - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - warpShape = std::move(dpasLayout.getShapeB()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeB()); - dimOuter = rank - 1; - dimInner = rank - 2; - repCluster[dimInner] = 1; - - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = i32_ty; - packedElemsNum = opsPerChannel; - usePackedType = true; - } - } break; - case DpasEncodingAttr::OpIdx::OperandC: - warpShape = std::move(dpasLayout.getShapeC()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeC()); - dimOuter = rank - 2; - dimInner = rank - 1; - usePackedType = false; + Type eltTy = getTypeConverter()->convertType(tensorType.getElementType()); + unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); + unsigned packedElemSizeInBits = elemSizeInBits * numPackedVals; + // 2D block store supports 64 bits element at most. + if (packedElemSizeInBits > 64) + return failure(); + // Tile width is not changeable. Return failure if it is not supported by + // HW. + switch (packedElemSizeInBits) { + case 8: + if (tileWidth < 4 || tileWidth > 64) + return failure(); break; - default: - llvm_unreachable("unknown DPAS operands index type."); + case 16: + if (tileWidth < 2 || tileWidth > 32) + return failure(); + break; + case 32: + if (tileWidth > 16) + return failure(); + break; + case 64: + if (tileWidth > 8) + return failure(); break; + default: + // invalid element type for 2D block store. + return failure(); } - unsigned elemsPerLanePerDPASInst = - product(dpasInstShape) / threadsPerWarp; - LLVMTypeConverter *typeConverter = getTypeConverter(); - Type unpackedDPASOperandType = LLVM::getVectorType( - typeConverter->convertType(eltTy), elemsPerLanePerDPASInst); - unsigned packedElemsPerLanePerDPASInst = - elemsPerLanePerDPASInst / packedElemsNum; - Type packedDPASOperandType = - LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst); + // 2D block load supports 64 bytes per row at most. + unsigned totalBytesPerRowPerMatrix = tileWidth * packedElemSizeInBits / 8; + vBlocks = std::min(vBlocks, (int)(64 / totalBytesPerRowPerMatrix)); + vBlocks = std::min(4, vBlocks); + // HW issue for vblock = 4 + vBlocks = vBlocks == 4 ? 1 : vBlocks; - unsigned outerDimTileNum = - mlir::ceil(tensorShape[dimOuter], warpShape[dimOuter]); - unsigned outerDimWarpNum = - std::min(warpsPerCTA[dimOuter], outerDimTileNum); - Value outerDimWarpId = - b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); - unsigned innerDimRequiredWarpNum = - mlir::ceil(tensorShape[dimInner], warpShape[dimInner]); - unsigned innerDimWarpNum = - std::min(warpsPerCTA[dimInner], innerDimRequiredWarpNum); - - // Step 3: Get the tile size of load. - unsigned tileWidth = dpasInstShape[threadOrder[rank - 2]]; - unsigned tileHeight = dpasInstShape[threadOrder[rank - 1]]; - unsigned vBlocks = 1; - unsigned numOperandsOuterDimPerLoad = 1; - unsigned numOperandsInnerDimPerLoad = 1; - unsigned maskConstancyHor = 1, maskConstancyVer = 1; - unsigned instWidth = dpasInstShape[threadOrder[rank - 2]]; - unsigned instHeight = dpasInstShape[threadOrder[rank - 1]]; + // TODO: use the axis info to general the handling for both regular pointer + // and block pointer. + const bool memoryRowMajor = isMemoryRowMajor(op); + unsigned contiguousDim = memoryRowMajor ? 
1 : 0; + const bool isTransposeRequired = contiguousDim != colDim; - bool otherIsSplatConstInt = false; - int64_t splatVal = 0; + if (isTransposeRequired) { + // TODO: support load column major data. + return failure(); + } - std::map, Value> ptrs; - std::map, Value> masks; - std::map, Value> others; + Location loc = op.getLoc(); + MLIRContext *ctx = op.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto typeConverter = getTypeConverter(); + StringAttr kRegister = S("register"); + std::vector> bases(regPackedBases->size()); + llvm::transform(*regPackedBases, bases.begin(), + [&](int base) { return std::vector{base}; }); + LinearLayout regMapping({{kRegister, bases}}, + {{kRegister, llEncoding->getInDimSize(kRegister)}}, + /*requireSurjective=*/true); - Value llPtr = adaptor.getPtr(); - Value llMask = adaptor.getMask(); - Value llOther = adaptor.getOther(); + unsigned numElems = getTotalElemsPerThread(op.getType()); + unsigned threadsPerWarp = + TritonGPUDialect::getThreadsPerWarp(op->getParentOfType()); SmallVector ptrElems, maskElems, otherElems; + Value baseWidth, baseHeight, pitch; + unsigned baseHeightInt; + Value mask = op.getMask(); + Value other = op.getOther(); // Get the LLVM values for pointers + Value llPtr = adaptor.getPtr(); ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems && "the number of pointer values is not matched with the number of " "elements"); + Value llMask = adaptor.getMask(); // Get the LLVM values for mask if (llMask) { maskElems = unpackLLElements(loc, llMask, rewriter); @@ -2000,445 +1913,343 @@ struct LoadOpToBlockIOConversion auto axisInfo = const_cast(axisAnalysisPass) .getAxisInfo(mask); + unsigned maskConstancyHor = 1, maskConstancyVer = 1; if (axisInfo) { - maskConstancyHor = axisInfo->getConstancy(rank - 1); - maskConstancyVer = axisInfo->getConstancy(rank - 2); - } else { - maskConstancyHor = 1; - maskConstancyVer = 1; + maskConstancyHor = axisInfo->getConstancy(colDim); + if ((maskConstancyHor - 1) & maskConstancyHor) + return failure(); // The maskConstancyHor has to be power of 2 for + // block IO. + maskConstancyVer = axisInfo->getConstancy(rowDim); + if ((maskConstancyVer - 1) & maskConstancyVer) + return failure(); // The maskConstancyVer has to be power of 2 for + // block IO. } - } else { - // no mask - maskConstancyHor = std::numeric_limits::max(); - maskConstancyVer = std::numeric_limits::max(); - } - // Check the constancy of the mask support to load the memory in 2D block. - if (!(maskConstancyHor >= instWidth && maskConstancyVer >= instHeight)) - return failure(); + // Check the constancy of the mask support to load the memory in 2D block. + if (!(maskConstancyHor >= (tileWidth * numPackedVals))) + return failure(); // The tileWidth and numPackedVals is not changeable + // for now. + + // Adjust the vBlock to fit the constancy of mask + vBlocks = std::min(vBlocks, mlir::ceil(maskConstancyHor, + tileWidth * numPackedVals)); + assert(((vBlocks - 1) & vBlocks) == 0 && + "the vBlocks has to be power of 2."); + + // Check the constancy of the mask support to load the memory in 2D block. 
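+      // The mask is only uniform across 'maskConstancyVer' rows. When that is
+      // smaller than the chosen tile height, shrink the tile height to the
+      // mask constancy and rotate the register bases so the loaded values
+      // still land in the order expected by the register mapping.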
+ if (maskConstancyVer < tileHeight) { + unsigned minTileHeight = + mlir::ceil(threadsPerWarp, tileWidth); + if (maskConstancyVer < minTileHeight) + return failure(); + + unsigned numBasesForPackedVals = __builtin_ctz(numPackedVals); + unsigned numBasesForTileWidth = + __builtin_ctz(mlir::ceil(tileWidth, threadsPerWarp)); + unsigned numBasesForNewTileHeight = + __builtin_ctz(maskConstancyVer / minTileHeight); + unsigned numBasesForOldTileHeight = + __builtin_ctz(tileHeight / minTileHeight); + unsigned numBasesForVBlocks = __builtin_ctz(vBlocks); + + std::vector> rearangeMap; + for (int i = 0; i < __builtin_ctz(numElems); i++) { + rearangeMap.emplace_back().push_back(1 << i); + } + + // Rotate the bases of the ole tile height after the vBlocks. + unsigned rotateStart = numBasesForPackedVals + numBasesForTileWidth + + numBasesForNewTileHeight; + unsigned rotateNum = + numBasesForOldTileHeight - numBasesForNewTileHeight; + std::rotate(rearangeMap.begin() + rotateStart, + rearangeMap.begin() + rotateStart + rotateNum, + rearangeMap.begin() + rotateStart + rotateNum + + numBasesForVBlocks); + + LinearLayout rearangeMapping( + {{kRegister, rearangeMap}}, + {{kRegister, regMapping.getInDimSize(kRegister)}}, + /*requireSurjective=*/true); + tileHeight = maskConstancyVer; + assert(((tileHeight - 1) & tileHeight) == 0 && + "the tileHeight has to be power of 2."); + rearangeMapping = rearangeMapping.compose(regMapping); + } + } // Get the LLVM values for `other` + Value llOther = adaptor.getOther(); DenseElementsAttr constAttr; if (other && isa(eltTy) && matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && isa(constAttr.getElementType())) { - otherIsSplatConstInt = true; - splatVal = constAttr.getSplatValue().getSExtValue(); - } - if (other) { + int64_t splatVal = constAttr.getSplatValue().getSExtValue(); + if (splatVal != 0) + otherElems = + SmallVector(numElems, b.int_val(elemSizeInBits, splatVal)); + } else if (other) { otherElems = unpackLLElements(loc, llOther, rewriter); } - // re-arrange the ptrs and masks to for large 2D block IO. - // Layout is unrelated to the scalar type. - SmallVector> offsets = - mlir::emitOffsetForLayout(encoding, tensorType); - for (size_t i = 0; i < ptrElems.size(); ++i) { - SmallVector offset = offsets[i]; - ptrs[offset] = ptrElems[i]; - if (llMask) - masks[offset] = maskElems[i]; - if (otherElems.size()) - others[offset] = otherElems[i]; - } + baseWidth = b.i32_val( + std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8))); + // If the stride is 0, we want to load only the first row. + int stride = getStride(ptr, 0); + baseHeightInt = (stride == 0 ? 1 : tileHeight); + baseHeight = b.i32_val(baseHeightInt); + pitch = getPitch(rewriter, ptr, elemSizeInBits); + if (!pitch) + return failure(); - unsigned numOperandsPer2DLoadM, numOperandsPer2DLoadN; - if (!isTransposeRequired) { - // Set the number of operands per 2D load to the maximum number from - // layout information. The number will be adjusted to fit the - // tensor pointers's shape, constancy and contiguity. 
+ int64_t numElemsPerLoad = + tileHeight * tileWidth * numPackedVals * vBlocks / threadsPerWarp; + + int64_t numValuesPerLoad = numElemsPerLoad / numPackedVals; + Type packedType = IntegerType::get(ctx, packedElemSizeInBits); + Type load2DGenXType = LLVM::getVectorType(packedType, numValuesPerLoad); + Type unpackedType = LLVM::getVectorType(eltTy, numElemsPerLoad); + + bool useVNNIFormat = false; + Type packedDPASOperandType; + if (hasDpasEncoding(tensorType) || hasDotDpasEncoding(tensorType)) { + + // There are three types for block load here for DPAS layout + // (Otherwise there are only two types are used.): + // 1. load2DGenXType - + // 2. packedDPASOperandType - + // 3. unpackedType - + // + // clang-format off + // The block load generated op: + // %0 = load_2d %ptr : + // %1 = shufflevector %0, %0, <8 x i32> + // %2 = shufflevector %0, %0, <8 x i32> + // %3 = bitcast %1 : -> + // %4 = bitcast %2 : -> + // + // clang-format on + // + // The DPAS generated op: + // clang-format off + // + // %5 = bitcast %3 : -> + // %6 = bitcast %4 : -> + // %7 = dpas %5, %6, %other : , , + // clang-format on + // + // The LLVM optimizer will remove the redundant expression of pack/unpack + // elements pair and bitcast pair. The final IR would be simple for dot + // product: + // clang-format off + // %0 = load_2d %ptr : + // %1 = shufflevector %0, %0, <8 x i32> + // %2 = shufflevector %0, %0, <8 x i32> + // %3 = dpas %1, %2, %other : , , + // clang-format on + // The packedDPASOperandType and the shufflevector operation define the + // computation flow for the dot product. + + DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); + auto dpasLayout = getDpasLayout(tensorType); switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: - numOperandsPer2DLoadM = repCluster[dimOuter]; - numOperandsPer2DLoadN = numReps[dimInner]; - break; - case DpasEncodingAttr::OpIdx::OperandB: - numOperandsPer2DLoadM = numReps[dimInner]; - numOperandsPer2DLoadN = repCluster[dimOuter]; - break; - case DpasEncodingAttr::OpIdx::OperandC: - numOperandsPer2DLoadM = repCluster[dimOuter]; - numOperandsPer2DLoadN = repCluster[dimInner]; - break; + case DpasEncodingAttr::OpIdx::OperandA: { + unsigned elemsPerLanePerDPASInst = + product(dpasLayout.getDPASInstShapeA()) / threadsPerWarp; + // Block 2D contain at least one DotOp A. + if (numElemsPerLoad >= elemsPerLanePerDPASInst) { + packedDPASOperandType = LLVM::getVectorType( + packedType, elemsPerLanePerDPASInst / numPackedVals); + unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst); + } + } break; + case DpasEncodingAttr::OpIdx::OperandB: { + assert(numPackedVals == 1 && + "invalid number of packed values for DPAS operand B."); + unsigned elemsPerLanePerDPASInst = + product(dpasLayout.getDPASInstShapeB()) / threadsPerWarp; + // Block 2D contain at least one DotOp B. + if (numElemsPerLoad >= elemsPerLanePerDPASInst) { + unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); + unsigned sysDepth = dpasLayout.getSystolicDepth(); + if (tileHeight >= (opsPerChannel * sysDepth) && + ((opsPerChannel == 4 && elemSizeInBits == 8) || + (opsPerChannel == 2 && elemSizeInBits == 16))) { + // Use the VNNI packing format for DotOp B layout. 
+ numValuesPerLoad = numElemsPerLoad / opsPerChannel; + packedType = i32_ty; + load2DGenXType = LLVM::getVectorType(packedType, numValuesPerLoad); + packedDPASOperandType = LLVM::getVectorType( + packedType, elemsPerLanePerDPASInst / opsPerChannel); + useVNNIFormat = true; + } else { + packedDPASOperandType = + LLVM::getVectorType(IntegerType::get(ctx, packedElemSizeInBits), + elemsPerLanePerDPASInst / numPackedVals); + } + unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst); + } + } break; + case DpasEncodingAttr::OpIdx::OperandC: { + unsigned elemsPerLanePerDPASInst = + product(dpasLayout.getDPASInstShapeC()) / threadsPerWarp; + // Block 2D contain at least one DotOp C. + if (numElemsPerLoad >= elemsPerLanePerDPASInst) { + packedDPASOperandType = LLVM::getVectorType( + packedType, elemsPerLanePerDPASInst / numPackedVals); + unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst); + } + } break; default: - llvm_unreachable("unknown DPAS operands index type."); - break; + llvm_unreachable("unexpected OpIdx::OperandC"); } - } else { - if (opIdx == DpasEncodingAttr::OpIdx::OperandA) - return failure(); - - if (!usePackedType) - return failure(); - - std::swap(tileHeight, tileWidth); - - // We can decompose the matrix returned by transposed large 2d load - // when threads per warp < column size. Otherwise we have to load one - // operand per inst. - // Note: the tileHeight and numOperandsPer2DLoadM are the column size - // now. - numOperandsPer2DLoadM = - (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1; - // The transpose 2d load only support 1 operand per inst on column. - // (vBlocks = 1) - numOperandsPer2DLoadN = 1; } - // adjust the mask constancy to fit the 2D load. - numOperandsPer2DLoadM = - std::min(numOperandsPer2DLoadM, maskConstancyHor / instWidth); - numOperandsPer2DLoadN = - std::min(numOperandsPer2DLoadN, maskConstancyVer / instHeight); - - // PVC 2D load supports 32 rows at most. Load multiple dot operands in by - // enlarging the tileHeight. - numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight); - - // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands - // by enlarging the vBlocks. - constexpr int MAX_WIDTH = 64; - unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8; - if (totalBytesPerRowPerDPASOp > MAX_WIDTH) - return failure(); - numOperandsPer2DLoadN = - std::min(numOperandsPer2DLoadN, MAX_WIDTH / totalBytesPerRowPerDPASOp); - - tileHeight = instHeight * numOperandsPer2DLoadM; - tileWidth = instWidth; - vBlocks = numOperandsPer2DLoadN; - - numOperandsOuterDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsPer2DLoadM - : numOperandsPer2DLoadN; - numOperandsInnerDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsPer2DLoadN - : numOperandsPer2DLoadM; - - if (isTransposeRequired) - std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad); - - unsigned numLoadPerOutRepCluster = - mlir::ceil(repCluster[dimOuter], numOperandsOuterDimPerLoad); - unsigned numLoadPerInnerRepCluster = - mlir::ceil(repCluster[dimInner], numOperandsInnerDimPerLoad); - - unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst * - numOperandsOuterDimPerLoad * - numOperandsInnerDimPerLoad; - Type load2DGenXType = - LLVM::getVectorType(loadResultElemType, numValuesPerLoad); - - // Step 4: Generates the load instruction. - // The stride for the tile replicates. 
- unsigned numRepOuter; - unsigned numRepInner; - unsigned repOuterStride = warpShape[dimOuter] * outerDimWarpNum; - unsigned repInnerStride; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: - case DpasEncodingAttr::OpIdx::OperandB: - numRepOuter = numReps[dimOuter]; - numRepInner = - mlir::ceil(numReps[dimInner], numOperandsInnerDimPerLoad); - repInnerStride = warpShape[dimInner] * numOperandsInnerDimPerLoad; - break; - case DpasEncodingAttr::OpIdx::OperandC: - numRepOuter = numReps[dimOuter]; - numRepInner = numReps[dimInner]; - repInnerStride = warpShape[dimInner] * innerDimWarpNum; - break; - default: - llvm_unreachable("unknown DPAS operands index type."); - break; - } - - Value pitch = getPitch(rewriter, ptr, elemSizeInBits); - if (!pitch) - return failure(); - - // If the stride is 0, we want to load only the first row. - int stride = getStride(ptr, 0); - unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight); - Value baseHeight = b.i32_val(baseHeightInt); - Value baseWidth = - b.i32_val(std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8))); - - StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - - const unsigned originalElemBits = elemSizeInBits; + SmallVector unpackedLoadedVals(numElems); + for (size_t elemIdx = 0; elemIdx < numElems; elemIdx += numElemsPerLoad) { + unsigned registerIdx = regMapping.apply({{kRegister, elemIdx}})[0].second; + Value addrElem = ptrElems[registerIdx]; + addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0); + + Value offsetX = b.i32_val(0); + Value offsetY = b.i32_val(0); + Value pred; + // Use the top-left address and mask of the block to store the data. + // (The first value refer by the elemIdx.) + if (maskElems.size()) { + assert(maskElems.size() == otherElems.size() && + "Invalid size of the masks."); + pred = targetInfo.shuffleIdx(rewriter, loc, maskElems[registerIdx], 0); + } - LDBG("Block io tile shape: [" - << tileHeight << ", " << tileWidth << "], vblocks: " << vBlocks - << ", numOperandsPerLoad: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad) - << "], number loads per repCluster: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numLoadPerOutRepCluster - : numLoadPerInnerRepCluster) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numLoadPerInnerRepCluster - : numLoadPerOutRepCluster) - << "], number repCluster: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepOuter - : numRepInner) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepInner - : numRepOuter) - << "]"); + if (pred) { + // We leverage the GPU block I/O hardware out-of-bound protection + // feature by setting the offset to an invalid value when 'pred' + // is false (the HW will not read out-of-bounds values). 
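+        // Later, after issuing the 2D block read, the result of the load is
+        // selected only where the mask evaluates to true; otherwise the
+        // 'other' value is used.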
+      offsetY = b.select(pred, offsetY, baseHeight);
+    }

-    ValueTable loadVals;
-    for (int inner = 0; inner < numRepInner; ++inner) {
-      for (int outer = 0; outer < numRepOuter; ++outer) {
-        for (int loadInner = 0; loadInner < numLoadPerInnerRepCluster;
-             ++loadInner) {
-          for (int loadOuter = 0; loadOuter < numLoadPerOutRepCluster;
-               ++loadOuter) {
-            unsigned offsetOuter =
-                outer * repOuterStride + loadOuter * dpasInstShape[dimOuter] *
-                                             numOperandsOuterDimPerLoad;
-            unsigned offsetInner =
-                inner * repInnerStride + loadInner * dpasInstShape[dimInner] *
-                                             numOperandsInnerDimPerLoad;
-            unsigned offsetM =
-                (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetOuter
-                                                            : offsetInner);
-            unsigned offsetN =
-                (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetInner
-                                                            : offsetOuter);
-
-            LDBG("Block load iterator: inner: "
-                 << inner << ", outer:" << outer << ", loadInner:" << loadInner
-                 << ", loadOuter:" << loadOuter << " offset: [" << offsetM
-                 << ", " << offsetN << "]");
-
-            Value offsetY = b.i32_val(0);
-            Value pred;
-            if (llMask) {
-              assert(masks.size() && "Invalid size of the masks.");
-              pred = targetInfo.shuffleIdx(rewriter, loc,
-                                           masks[{offsetM, offsetN}], 0);
-              // We leverage the GPU block I/O hardware out-of-bound protection
-              // feature by setting the offset to an invalid value when 'pred'
-              // is false (the HW will not read out-of-bounds values). Later on,
-              // after issuing the 2d block read operation, we will select the
-              // result of the load only if the mask evaluate to true, otherwise
-              // we will use 'other'.
-              offsetY = b.select(pred, offsetY, baseHeight);
+    Value ret = rewriter.create(
+        loc, load2DGenXType,
+        /*ptr*/ addrElem,
+        /*base_width*/ baseWidth,
+        /*base_height*/ baseHeight,
+        /*base_pitch*/ pitch,
+        /*x*/ offsetX,
+        /*y*/ offsetY,
+        /*elem_size_in_bits*/ packedElemSizeInBits,
+        /*tile_width*/ tileWidth,
+        /*tile_height*/ tileHeight,
+        /*v_blocks*/ vBlocks,
+        /*transpose*/ false,
+        /*vnni_transform*/ useVNNIFormat);
+
+    // When strides[0] is 0, we only want to load the first row, so we
+    // set the base height to 1. If the tile height is larger than 1,
+    // then only the first row contains valid data. To ensure the entire
+    // tile is filled with valid data, we must replicate the first row
+    // throughout the tile.
+    if (baseHeightInt < tileHeight && baseHeightInt == 1) {
+      unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks;
+      SmallVector shuffleIndices(numValuesPerLoad);
+
+      // Create a vector to store the data of the first index of each
+      // matrix.
+      VectorType vecTy = vec_ty(packedType, vBlocks);
+      Value firstIndexVec = b.undef(vecTy);
+
+      for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad;
+           ++valueIndex) {
+        unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix;
+        // Handle the case where an index spans two rows.
+        if (valueIndex % numIndicesPerMatrix == 0) {
+          Value oldVal = b.extract_element(ret, b.i32_val(valueIndex));
+          Value newVal = oldVal;
+          if (tileWidth < threadsPerWarp) {
+            assert(tileWidth * 2 == threadsPerWarp &&
+                   "Expecting threadsPerWarp to be 2x tileWidth");
+            Value threadId = getThreadId(rewriter, loc);
+            newVal =
+                targetInfo.shuffleIdx(rewriter, loc, oldVal,
+                                      b.urem(threadId, b.i32_val(tileWidth)));
          }
+          firstIndexVec =
+              b.insert_element(firstIndexVec.getType(), firstIndexVec, newVal,
+                               b.i32_val(firstIndexVecIdx));
+        }
-
-          // Use the top-left address of the block to load the data.
- Value addrElem = - b.bitcast(ptrs[{offsetM, offsetN}], ptr_ty(ctx, 1 /*global*/)); - addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0); - - Value ret = rewriter.create( - loc, load2DGenXType, - /*ptr*/ addrElem, - /*base_width*/ baseWidth, - /*base_height*/ baseHeight, - /*base_pitch*/ pitch, - /*x*/ b.i32_val(0), - /*y*/ offsetY, - /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ tileWidth, - /*tile_height*/ tileHeight, - /*v_blocks*/ vBlocks, - /*transpose*/ false, - /*vnni_transform*/ - (usePackedType && opIdx == DpasEncodingAttr::OpIdx::OperandB && - !isTransposeRequired && originalElemBits != 32)); - - // When strides[0] is 0, we only want to load the first row, so we - // set the base height to be 1. If tile height is bigger than 1, - // then only the first row contain valid data. To ensure the entire - // tile is filled with valid data, we must replicate the first row - // throughout the tile. - if (baseHeightInt < tileHeight && baseHeightInt == 1) { - unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks; - SmallVector shuffleIndices(numValuesPerLoad); - - // Create a vector to store the data of the first index of each - // matrix. - VectorType vecTy = vec_ty(loadResultElemType, vBlocks); - Value firstIndexVec = b.undef(vecTy); - - for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad; - ++valueIndex) { - unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix; - // Handle case where an index spans two rows. - if (valueIndex % numIndicesPerMatrix == 0) { - Value oldVal = b.extract_element(ret, b.i32_val(valueIndex)); - Value newVal = oldVal; - if (tileWidth < threadsPerWarp) { - assert(tileWidth * 2 == threadsPerWarp && - "Expecting tileWidth to be 2x threadsPerWarp"); - Value threadId = getThreadId(rewriter, loc); - newVal = targetInfo.shuffleIdx( - rewriter, loc, oldVal, - b.urem(threadId, b.i32_val(tileWidth))); - } - firstIndexVec = - b.insert_element(firstIndexVec.getType(), firstIndexVec, - newVal, b.i32_val(firstIndexVecIdx)); - } - - shuffleIndices[valueIndex] = firstIndexVecIdx; - } - DenseI32ArrayAttr attr = - rewriter.getDenseI32ArrayAttr(shuffleIndices); - ret = rewriter.create( - loc, load2DGenXType, firstIndexVec, firstIndexVec, attr); - } + shuffleIndices[valueIndex] = firstIndexVecIdx; + } + DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(shuffleIndices); + ret = rewriter.create( + loc, load2DGenXType, firstIndexVec, firstIndexVec, attr); + } - if (others.size()) { - assert(masks.size() == others.size() && - "The mask value has to be provided when " - "the other value is provided."); - VectorType vecTy = - vec_ty(eltTy, numValuesPerLoad * packedElemsNum); - - Value v = b.undef(vecTy); - unsigned nWords = 0; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int i = 0; i < tileHeight; ++i) { - unsigned numColPerPackedValue = - opIdx == DpasEncodingAttr::OpIdx::OperandA - ? 
packedElemsNum - : 1; - unsigned numPackedValuesPerRow = mlir::ceil( - (tileWidth / numColPerPackedValue), threadsPerWarp); - for (int col = 0; col < numPackedValuesPerRow; ++col) { - for (int packedCol = 0; packedCol < numColPerPackedValue; - ++packedCol) { - unsigned N = packedCol + - col * threadsPerWarp * numColPerPackedValue + - vblk * tileWidth + offsetN; - unsigned M = i + offsetM; - Value falseVal = others[{M, N}]; - Value sVal = createIndexAttrConstant( - rewriter, loc, typeConverter->getIndexType(), - nWords++); - v = b.insert_element(vecTy, v, falseVal, sVal); - } - } - } - Value others = b.bitcast(v, load2DGenXType); - ret = b.select(pred, ret, others); - } + unsigned numValsPerDPASOperand = + packedDPASOperandType + ? LLVM::getVectorNumElements(packedDPASOperandType) + .getKnownMinValue() + : numValuesPerLoad; + + unsigned numElemsPerUnpackedType = + LLVM::getVectorNumElements(unpackedType).getKnownMinValue(); + unsigned numOperandsPerLoad = numValuesPerLoad / numValsPerDPASOperand; + for (size_t opsIdx = 0; opsIdx < numOperandsPerLoad; ++opsIdx) { + + Value unpackedVal; + if (numValsPerDPASOperand != numValuesPerLoad) { + // Decompose the return value to multiple DPAS operands. + SmallVector indices(numValsPerDPASOperand); + for (int i = 0; i < numValsPerDPASOperand; ++i) { + indices[i] = opsIdx * numValsPerDPASOperand + i; + } + DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices); + Value dpasOperand = rewriter.create( + loc, packedDPASOperandType, ret, ret, attr); - unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad; - unsigned numOperandsN = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad; - - // Split the return matrix by large 2d block io size into multiple - // DPAS operands. 
- assert(numOperandsN >= vBlocks && - "numOperandsN has to be >= vBlocks"); - unsigned numOperandsPerVBlockN = numOperandsN / vBlocks; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int row = 0; row < numOperandsM; ++row) - for (int col = 0; col < numOperandsPerVBlockN; ++col) { - - unsigned operandStartOffset = (vblk * numOperandsM + row) * - numOperandsPerVBlockN * - packedElemsPerLanePerDPASInst; - - SmallVector indices(packedElemsPerLanePerDPASInst); - for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst; - ++elemIdx) { - indices[elemIdx] = operandStartOffset + - elemIdx * numOperandsPerVBlockN + col; - } + unpackedVal = b.bitcast(dpasOperand, unpackedType); - LLVM_DEBUG({ - DBGS() << "shuffle idx: ["; - for (int elemIdx = 0; - elemIdx < packedElemsPerLanePerDPASInst; ++elemIdx) { - llvm::dbgs() << indices[elemIdx] << ", "; - } - llvm::dbgs() << "]\n"; - }); + } else { + unpackedVal = b.bitcast(ret, unpackedType); + } - DenseI32ArrayAttr attr = - rewriter.getDenseI32ArrayAttr(indices); - Value loadVal = rewriter.create( - loc, packedDPASOperandType, ret, ret, attr); - - // Save the decomposed vals to the map; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandC: - case DpasEncodingAttr::OpIdx::OperandA: { - unsigned o = outer * numLoadPerOutRepCluster * - numOperandsOuterDimPerLoad + - loadOuter * numOperandsOuterDimPerLoad + row; - unsigned i = inner * numLoadPerInnerRepCluster * - numOperandsInnerDimPerLoad + - loadInner * numOperandsInnerDimPerLoad + - vblk * numOperandsPerVBlockN + col; - - LDBG("insert: [" << o << ", " << i << "]"); - loadVals[{o, i}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - unsigned o = outer * numLoadPerOutRepCluster * - numOperandsOuterDimPerLoad + - loadOuter * numOperandsOuterDimPerLoad + - vblk * numOperandsPerVBlockN + col; - unsigned i = inner * numOperandsInnerDimPerLoad + row; - LDBG("insert: [" << o << ", " << i << "]"); - loadVals[{o, i}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - default: { - llvm_unreachable("unknown DPAS operands index type."); - } break; - } - } + // select with others on unpacked type. + if (otherElems.size()) { + assert(maskElems.size() == otherElems.size() && + "Invalid size of the masks."); + Value other = b.undef(unpackedType); + for (size_t i = 0; i < numElemsPerUnpackedType; ++i) { + unsigned registerIdx = + regMapping + .apply( + {{kRegister, + elemIdx + opsIdx * numElemsPerUnpackedType + i}})[0] + .second; + Value falseVal = otherElems[registerIdx]; + other = b.insert_element(other, falseVal, b.i32_val(i)); } + unpackedVal = b.select(pred, unpackedVal, other); } - } - } - // Step 5: Unpack the load values. - // Extract the value returned by the load ops. And put the values in the - // expected order for the layout. 
-  SmallVector unpackedLoadedVals;
-  for (int outer = 0; outer < numReps[dimOuter]; ++outer) {
-    for (int inner = 0; inner < numReps[dimInner]; ++inner) {
-      for (int repOuter = 0; repOuter < repCluster[dimOuter]; ++repOuter) {
-        for (int repInner = 0; repInner < repCluster[dimInner]; ++repInner) {
-          unsigned o = outer * repCluster[dimOuter] + repOuter;
-          unsigned i = inner * repCluster[dimInner] + repInner;
-          LDBG("extract: [" << o << ", " << i << "]");
-          Value loadVal = loadVals.at({o, i});
-          VectorType loadTy = cast(loadVal.getType());
-          for (int i = 0; i < loadTy.getNumElements(); ++i) {
-            auto val = b.extract_element(loadVal, b.i32_val(i));
-            unpackedLoadedVals.push_back(val);
-          }
-          loadVals.erase({o, i});
-        }
+      for (int i = 0; i < numElemsPerUnpackedType; ++i) {
+        unsigned registerIdx =
+            regMapping
+                .apply({{kRegister,
+                         elemIdx + opsIdx * numElemsPerUnpackedType + i}})[0]
+                .second;
+        Value val = b.extract_element(unpackedVal, b.i32_val(i));
+
+        // TODO: this is an alternative implementation for otherElems that
+        // uses more `select` ops but has lower register pressure. Need to
+        // compare the impact of the two implementations.
+        // if (otherElems.size()) {
+        //   val = b.select(pred, val, otherElems[registerIdx]);
+        // }
+        unpackedLoadedVals[registerIdx] = val;
+      }
    }
  }
-
-  assert(loadVals.empty() && "not all loaded values is unpacked.");
-
  Type llvmResultStructTy = typeConverter->convertType(op.getType());
  Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
                                      rewriter, llvmResultStructTy);