Skip to content

Commit dca7748

Browse files
Change DPAS from OCL to SPV builtin (#3732)
SPEC: [SPV_INTEL_subgroup_matrix_multiply_accumulate](https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_subgroup_matrix_multiply_accumulate.html)
Benchmark CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15145720901 (no geomean regression)
Closes #3731
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 58f5923 commit dca7748

File tree

7 files changed

+118
-78
lines changed

7 files changed

+118
-78
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -814,7 +814,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
814814
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
815815
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #dpas0>
816816

817-
// CHECK-COUNT-2: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
817+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
818818
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #dpas0>
819819

820820
tt.return
@@ -964,7 +964,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
964964
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
965965
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>
966966

967-
// CHECK-COUNT-128: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
967+
// CHECK-COUNT-128: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
968968
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #dpas>
969969
%38 = ttg.convert_layout %28 : tensor<128x256xf32, #dpas> -> tensor<128x256xf32, #blocked>
970970

@@ -991,7 +991,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
991991
%a_mat = ttg.local_load %a : !ttg.memdesc<32x64xf16, #shared0, #smem> -> tensor<32x64xf16, #dot_operand_a>
992992
%b_mat = ttg.local_load %b : !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<64x64xf16, #dot_operand_b>
993993

994-
// CHECK-COUNT-16: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
994+
// CHECK-COUNT-16: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
995995
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #dpas>
996996
%38 = ttg.convert_layout %28 : tensor<32x64xf32, #dpas> -> tensor<32x64xf32, #blocked>
997997
%30 = tt.splat %ptr : !tt.ptr<f32> -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1040,7 +1040,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10401040
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
10411041
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
10421042

1043-
// CHECK-COUNT-2: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv8_fS_S_(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
1043+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
10441044
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
10451045
%38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
10461046

0 commit comments

Comments
 (0)