@@ -814,7 +814,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #dpas0>
- // CHECK-COUNT-2: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #dpas0>
tt.return
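For reference, the updated checks match the SPIR-V friendly builtin __spirv_SubgroupMatrixMultiplyAccumulateINTEL in place of the legacy intel_sub_group_f16_f16_matrix_mad_k16 builtin. The new builtin takes two extra i32 operands: a leading K-dimension argument and a trailing matrix-operands flag. Below is a minimal standalone sketch of the lowered form these CHECK lines match, with illustrative constants (K = 16 comes from the old builtin's name; the flags value 0 is an assumed placeholder, not taken from the test):

// Hypothetical module sketching the shape of the lowered call.
llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>

llvm.func @mma_f16_sketch(%a: vector<8xi16>, %b: vector<8xi32>, %acc: vector<8xf32>) -> vector<8xf32> {
  // Leading i32: the K dimension of the multiply-accumulate (16 for f16).
  %k = llvm.mlir.constant(16 : i32) : i32
  // Trailing i32: matrix-operands flags; 0 is an assumed placeholder.
  %flags = llvm.mlir.constant(0 : i32) : i32
  %d = llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%k, %a, %b, %acc, %flags) : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
  llvm.return %d : vector<8xf32>
}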
@@ -964,7 +964,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>
- // CHECK-COUNT-128: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-128: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #dpas>
%38 = ttg.convert_layout %28 : tensor<128x256xf32, #dpas> -> tensor<128x256xf32, #blocked>
@@ -991,7 +991,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%a_mat = ttg.local_load %a : !ttg.memdesc<32x64xf16, #shared0, #smem> -> tensor<32x64xf16, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<64x64xf16, #dot_operand_b>
- // CHECK-COUNT-16: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-16: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #dpas>
%38 = ttg.convert_layout %28 : tensor<32x64xf32, #dpas> -> tensor<32x64xf32, #blocked>
%30 = tt.splat %ptr : !tt.ptr<f32> -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1040,7 +1040,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
- // CHECK-COUNT-2: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv8_fS_S_(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
%38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
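The tf32 variant lowers to the same builtin with vector<8xf32> A and B operands; in the Itanium mangling iDv8_fS_S_i, S_ is a substitution for the repeated vector<8xf32> type. A sketch of the matched call under the same assumptions (K = 8 comes from the old matrix_mad_k8 builtin's name; the tf32 interpretation would be carried by the trailing flags operand, whose value below is a placeholder only):

llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>

llvm.func @mma_tf32_sketch(%a: vector<8xf32>, %b: vector<8xf32>, %acc: vector<8xf32>) -> vector<8xf32> {
  // Leading i32: the K dimension (8 for the tf32 multiply-accumulate).
  %k = llvm.mlir.constant(8 : i32) : i32
  // Trailing i32: matrix-operands flags marking A/B as tf32; placeholder value.
  %flags = llvm.mlir.constant(0 : i32) : i32
  %d = llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%k, %a, %b, %acc, %flags) : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
  llvm.return %d : vector<8xf32>
}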