From 68d18dda370430979ce554ac4b4ed18d87d5ced9 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 26 Mar 2025 15:05:11 +0000 Subject: [PATCH 1/2] Change 2D block store from OCL to SPV Signed-off-by: Whitney Tsang --- ...tritongpu_to_llvm_intel_advanced_path.mlir | 6 +-- .../tritongen-2Dblockstore-to-llvm.mlir | 31 ++++++++++---- test/TritonIntelGPU/blockptr_store.mlir | 28 ++++++------- .../TritonGENToLLVM/TritonGENToLLVMPass.cpp | 40 ++++++++++--------- 4 files changed, 62 insertions(+), 43 deletions(-) diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir index 33a10b5c0d..a0e582b229 100644 --- a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir +++ b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir @@ -4,7 +4,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup // CHECK-DAG: llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} // CHECK-DAG: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} // CHECK-DAG: llvm.func spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} - // CHECK-DAG: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} + // CHECK-DAG: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return} // CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { @@ -102,7 +102,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup cf.br ^bb1(%119, %71, %115, %117, %118 : i32, tensor<8x16xf32>, !tt.ptr, 1>, !tt.ptr, 1>, !tt.ptr, 1>) ^bb3: %120 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%21, %36] {order = array} : , 1> - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%arg2, {{.*}} + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %arg2, {{.*}} tt.store %120, %41 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1> tt.return } @@ -140,7 +140,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<8x16xf16> %0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array} : > - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%arg0, {{.*}}) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %arg0, {{.*}}) tt.store %0, %cst {boundaryCheck = array} : !tt.ptr> tt.return } diff --git a/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir b/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir index 9b4f81f926..5878044b06 100644 --- a/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir +++ b/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s -// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_write_8b_8r16x1cPU3AS1viiiDv2_iPh(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return} llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) { // CHECK: llvm.func @triton_gen.2Dblockstore(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: vector<8xi8>) { @@ -20,9 +20,13 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement [[ADD_1]], [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> - // CHECK-NEXT: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_write_8b_8r16x1cPU3AS1viiiDv2_iPh(%arg0, [[ADD_0]], %arg2, %arg3, [[COORD1]], [[STOREVALPTR]]) - // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, Uncached, 0>> - // CHECK-SAME: : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> () + // CHECK-NEXT: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[STOREVALPTR]], %arg0, [[ADD_0]], %arg2, %arg3, [[COORD1]]) + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 5>, #triton_gen.store_cache_control<1, Uncached, 5>> + // CHECK-SAME: : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, cache_control=L1UC_L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi8>) llvm.return } @@ -30,7 +34,12 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // ----- llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { - // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_write_8b_8r32x1cPU3AS1viiiDv2_iPt(%arg0, %{{.*}}, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> () + // CHECK: llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) llvm.return } @@ -38,7 +47,11 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // ----- llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%arg0, %{{.*}}, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> () + // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) llvm.return } @@ -46,7 +59,11 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // ----- llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) { - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%arg0, %{{.*}}, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> () + // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) llvm.return } diff --git a/test/TritonIntelGPU/blockptr_store.mlir b/test/TritonIntelGPU/blockptr_store.mlir index d3b2e2d497..ba5dca6cb8 100644 --- a/test/TritonIntelGPU/blockptr_store.mlir +++ b/test/TritonIntelGPU/blockptr_store.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm -// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return} #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> @@ -39,16 +39,16 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: llvm.mul %[[VAL_1]], %[[CST_16]] : i32 // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16> - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt{{.*}} + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}} // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16> - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt{{.*}} + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}} // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16> - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt{{.*}} + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}} // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16> - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt{{.*}} + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}} tt.store %13, %12 {boundaryCheck = array} : !tt.ptr> tt.return } @@ -56,7 +56,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // ----- -// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return} +// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return} #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: llvm.func spir_kernelcc @dpas_layout_2d_store_rep_cluster_4_2( @@ -214,7 +214,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_204:.*]] = llvm.mlir.undef : vector<2xi32> // CHECK: %[[VAL_205:.*]] = llvm.insertelement %{{.*}}, %[[VAL_204]]{{\[}}%[[VAL_203]] : i32] : vector<2xi32> // CHECK: %[[VAL_206:.*]] = llvm.insertelement %[[VAL_198]], %[[VAL_205]]{{\[}}%[[VAL_202]] : i32] : vector<2xi32> - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], %[[VAL_206]], %[[VAL_201]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_201]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], %[[VAL_206]]) // COM: replica [0, 1] // CHECK: %[[VAL_207:.*]] = llvm.mlir.constant(16 : i32) : i32 @@ -232,7 +232,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_229:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_230:.*]] = llvm.alloca %[[VAL_229]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_226]], %[[VAL_230]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_230]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_230]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) // COM: replica [1, 0] // CHECK: %[[VAL_236:.*]] = llvm.mlir.constant(8 : i32) : i32 @@ -252,7 +252,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_260:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_261:.*]] = llvm.alloca %[[VAL_260]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_257]], %[[VAL_261]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_261]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_261]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) // COM: replica [1, 1] // CHECK: %[[VAL_267:.*]] = llvm.mlir.constant(16 : i32) : i32 @@ -270,7 +270,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_289:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_290:.*]] = llvm.alloca %[[VAL_289]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_286]], %[[VAL_290]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_290]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_290]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) // COM: replica [2, 0] // CHECK: %[[VAL_296:.*]] = llvm.mlir.constant(16 : i32) : i32 @@ -290,7 +290,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_320:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_321:.*]] = llvm.alloca %[[VAL_320]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_317]], %[[VAL_321]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_321]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_321]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) // COM: replica [2, 1] // CHECK: %[[VAL_327:.*]] = llvm.mlir.constant(16 : i32) : i32 @@ -308,7 +308,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_349:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_350:.*]] = llvm.alloca %[[VAL_349]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_346]], %[[VAL_350]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_350]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_350]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) // COM: replica [3, 0] // CHECK: %[[VAL_356:.*]] = llvm.mlir.constant(24 : i32) : i32 @@ -328,7 +328,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_380:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_381:.*]] = llvm.alloca %[[VAL_380]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_377]], %[[VAL_381]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_381]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_381]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) // COM: replica [3, 1] // CHECK: %[[VAL_387:.*]] = llvm.mlir.constant(16 : i32) : i32 @@ -346,7 +346,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} // CHECK: %[[VAL_409:.*]] = llvm.mlir.constant(8 : i32) : i32 // CHECK: %[[VAL_410:.*]] = llvm.alloca %[[VAL_409]] x i16 : (i32) -> !llvm.ptr // CHECK: llvm.store %[[VAL_406]], %[[VAL_410]] : vector<8xi16>, !llvm.ptr - // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}, %[[VAL_410]]) + // CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_410]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}}) tt.store %13, %cst {boundaryCheck = array} : !tt.ptr> tt.return diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 04a158e913..c1ead321b5 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -574,41 +574,43 @@ struct TritonMatrix2DBlockStoreLowering b.i32_val(storeValType.getNumElements())); rewriter.create(loc, op.getStoredVal(), storeValPtr); - std::string fnName = "intel_sub_group_2d_block_write_"; - fnName += std::to_string(op.getElemSizeInBits()) + "b_" + - std::to_string(op.getTileHeight()) + "r" + - std::to_string(op.getTileWidth()) + "x" + - std::to_string(op.getVBlocks()) + "c"; - fnName = "_Z" + std::to_string(fnName.size()) + fnName + "PU3AS1viiiDv2_iP"; - unsigned storeValBitWidth = - storeValType.getElementType().getIntOrFloatBitWidth(); - fnName += (storeValBitWidth == 32) ? "j" - : (storeValBitWidth == 16) ? "t" - : "h"; + std::string fnName = "__spirv_Subgroup2DBlockStoreINTEL"; auto [baseWidth, offsetX] = computeAlignedBaseWidthAndOffset(op, rewriter); VectorType vecType = vec_ty(i32_ty, 2); + SmallVector argTypes{i32_ty, i32_ty, i32_ty, i32_ty, + ptr_ty(ctx), ptr_ty(ctx, 1), i32_ty, i32_ty, + i32_ty, vecType}; + fnName = intel::mangle(fnName, argTypes); + Value byteCoord = b.insert_element( vecType, b.insert_element(vecType, b.undef(vecType), offsetX, b.i32_val(0)), op.getY(), b.i32_val(1)); - SmallVector argTypes{ptr_ty(ctx, 1), i32_ty, i32_ty, - i32_ty, vecType, ptr_ty(ctx)}; - SmallVector args{op.getPtr(), baseWidth, op.getBaseHeight(), - op.getBasePitch(), byteCoord, storeValPtr}; + + SmallVector args{b.i32_val(op.getElemSizeInBits() / 8), + b.i32_val(op.getTileWidth()), + b.i32_val(op.getTileHeight()), + b.i32_val(op.getVBlocks()), + storeValPtr, + op.getPtr(), + baseWidth, + op.getBaseHeight(), + op.getBasePitch(), + byteCoord}; std::array, 4> paramAttrs{ - std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), - std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), - std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()), + std::make_pair(4, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(4, LLVM::LLVMDialect::getReadonlyAttrName()), }; LLVM::CallOp call = intel::createDeviceFunctionCall( rewriter, fnName, void_ty(ctx), argTypes, args, paramAttrs, intel::noUnwindWillReturnAttrs); - constexpr uint32_t ptrOperandIndex = 0; + constexpr uint32_t ptrOperandIndex = 5; if (std::optional optCacheControls = storeCacheControlToCacheControls(rewriter, op.getCacheControl(), ptrOperandIndex)) { From d905198023bb2a4a8498457016b4ff1006667c97 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 22 May 2025 03:02:09 +0000 Subject: [PATCH 2/2] Fix lit Signed-off-by: Whitney Tsang --- .../tritongen-2Dblockstore-to-llvm.mlir | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir b/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir index 5878044b06..36f8b58a84 100644 --- a/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir +++ b/test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir @@ -20,10 +20,10 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement [[ADD_1]], [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> - // CHECK-NEXT: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-DAG: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[STOREVALPTR]], %arg0, [[ADD_0]], %arg2, %arg3, [[COORD1]]) // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 5>, #triton_gen.store_cache_control<1, Uncached, 5>> // CHECK-SAME: : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () @@ -34,12 +34,12 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // ----- llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { - // CHECK: llvm.mlir.constant(1 : i32) : i32 - // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(32 : i32) : i32 - // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () + // CHECK-COUNT-2: llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[TileWidth:%.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) llvm.return } @@ -47,10 +47,11 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // ----- llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { + // CHECK: llvm.mlir.constant(2 : i32) : i32 // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-DAG: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) llvm.return @@ -59,10 +60,11 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base // ----- llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) { + // CHECK: llvm.mlir.constant(4 : i32) : i32 // CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-DAG: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], %arg0, %{{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) llvm.return