Skip to content

Commit c9022fc

Browse files
alexbaden authored and chengjunlu committed
fixup lit tests
1 parent 503bcbe commit c9022fc

File tree

1 file changed

+62
-42
lines changed

1 file changed

+62
-42
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 62 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -676,6 +676,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
676676
// CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
677677
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[BCAST0]], [[VEC1]][[[CST_0]] : i32] : vector<1xf32>
678678
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to i32
679+
// CHECK-NEXT: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1
679680
// CHECK-NEXT: [[AND1:%.*]] = llvm.and {{.*}}, [[ARG2_0]] : i1
680681
// CHECK-NEXT: [[VEC2:%.*]] = llvm.mlir.undef : vector<1xi32>
681682
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -1059,17 +1060,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10591060
module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10601061
// CHECK-LABEL: atomic_cas_f32_scalar_no_store
10611062
tt.func @atomic_cas_f32_scalar_no_store(%ptr : !tt.ptr<f32>, %cmp : f32, %val : f32) {
1062-
// CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
1063-
// CHECK: [[CMP0:%.*]] = llvm.icmp "eq"
1064-
// CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]]
1065-
// CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1066-
// CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
1067-
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1063+
// CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1064+
// CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1065+
// CHECK: [[MASKLANE:%.*]] = llvm.and
1066+
// CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1067+
// CHECK: [[MASKWARP:%.*]] = llvm.and
1068+
// CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1069+
// CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1070+
// CHECK: llvm.mlir.constant(-1 : i32) : i32
1071+
// CHECK: [[MASKBLOCK:%.*]] = llvm.and
1072+
// CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1073+
// CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1074+
// CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
10681075
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
10691076
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
10701077
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
10711078
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
1072-
// CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
1079+
// CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32)
10731080
// CHECK-NEXT: ^bb1:
10741081
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
10751082
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32
@@ -1089,13 +1096,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
10891096
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
10901097
// CHECK-LABEL: atomic_cas_f32_scalar
10911098
tt.func @atomic_cas_f32_scalar(%ptr : !tt.ptr<f32>, %cmp : f32, %val : f32) {
1092-
// CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
1093-
// CHECK: [[CMP0:%.*]] = llvm.icmp "eq"
1094-
// CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]]
1095-
// CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1096-
// CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
1097-
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1098-
// CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
1099+
// CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1100+
// CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1101+
// CHECK: [[MASKLANE:%.*]] = llvm.and
1102+
// CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1103+
// CHECK: [[MASKWARP:%.*]] = llvm.and
1104+
// CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1105+
// CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1106+
// CHECK: llvm.mlir.constant(-1 : i32) : i32
1107+
// CHECK: [[MASKBLOCK:%.*]] = llvm.and
1108+
// CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1109+
// CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1110+
// CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1111+
// CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32)
10991112
// CHECK-NEXT: ^bb1:
11001113
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
11011114
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32
@@ -1131,14 +1144,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11311144
// CHECK-NEXT: [[EV1_ARG2:%.*]] = llvm.extractvalue %arg2[1] : !llvm.struct<(f32, f32)>
11321145
// CHECK: [[EV0_ARG0:%.*]] = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr<1>, ptr<1>)>
11331146
// CHECK-NEXT: [[EV1_ARG0:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)>
1134-
// CHECK: llvm.mlir.constant(true) : i1
1135-
// CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1136-
// CHECK: [[PRED0:%.*]] = llvm.and [[CST_TRUE]], {{.*}} : i1
1137-
// CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1147+
// CHECK: [[EV0_ARG1:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(i1, i1)>
1148+
// CHECK-NEXT: [[EV1_ARG1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(i1, i1)>
1149+
// CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
11381150
// CHECK: [[IE1:%.*]] = llvm.insertelement [[EV0_ARG2]], [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1139-
// CHECK-NEXT: [[PRED1:%.*]] = llvm.and [[PRED0]], {{.*}} : i1
11401151
// CHECK-NEXT: [[ZERO1:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1141-
// CHECK: llvm.cond_br [[PRED1]], ^bb1, ^bb2([[ZERO1]] : f32)
1152+
// CHECK: llvm.cond_br [[EV0_ARG1]], ^bb1, ^bb2([[ZERO1]] : f32)
11421153
// CHECK-NEXT: ^bb1:
11431154
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32
11441155
// CHECK-NEXT: [[RMW_RES1:%.*]] = llvm.atomicrmw fadd [[EV0_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
@@ -1147,13 +1158,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11471158
// CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI1]] : f32 to f32
11481159
// CHECK-NEXT: [[UNDEF2:%.*]] = llvm.mlir.undef : vector<1xf32>
11491160
// CHECK: [[IE2:%.*]] = llvm.insertelement [[EV1_ARG2]], [[UNDEF2]][{{.*}} : i64] : vector<1xf32>
1150-
// CHECK-NEXT: [[PRED2:%.*]] = llvm.and [[PRED0]], {{.*}} : i1
11511161
// CHECK-NEXT: [[ZERO2:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
11521162
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
11531163
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
11541164
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
11551165
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
1156-
// CHECK-NEXT: llvm.cond_br [[PRED2]], ^bb3, ^bb4([[ZERO2]] : f32)
1166+
// CHECK-NEXT: llvm.cond_br [[EV1_ARG1]], ^bb3, ^bb4([[ZERO2]] : f32)
11571167
// CHECK-NEXT: ^bb3:
11581168
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE2]] : vector<1xf32> to f32
11591169
// CHECK-NEXT: [[RMW_RES2:%.*]] = llvm.atomicrmw fadd [[EV1_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
@@ -1169,14 +1179,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11691179
module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11701180
// CHECK-LABEL: atomic_add_f32_scalar_no_store
11711181
tt.func @atomic_add_f32_scalar_no_store(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
1172-
// CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1173-
// CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1174-
// CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1
1175-
// CHECK: [[CMP1:%.*]] = llvm.icmp "eq"
1176-
// CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1
1177-
// CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1182+
// CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1183+
// CHECK: [[MASKLANE:%.*]] = llvm.and
1184+
// CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1185+
// CHECK: [[MASKWARP:%.*]] = llvm.and
1186+
// CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1187+
// CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1188+
// CHECK: llvm.mlir.constant(-1 : i32) : i32
1189+
// CHECK: [[MASKBLOCK:%.*]] = llvm.and
1190+
// CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1191+
// CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1192+
// CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
11781193
// CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1179-
// CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
1194+
// CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1
11801195
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
11811196
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
11821197
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
@@ -1200,14 +1215,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
12001215
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
12011216
// CHECK-LABEL: atomic_add_f32_scalar
12021217
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
1203-
// CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1204-
// CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1205-
// CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1
1206-
// CHECK: [[CMP1:%.*]] = llvm.icmp "eq"
1207-
// CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1
1208-
// CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1218+
// CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1219+
// CHECK: [[MASKLANE:%.*]] = llvm.and
1220+
// CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1221+
// CHECK: [[MASKWARP:%.*]] = llvm.and
1222+
// CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1223+
// CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1224+
// CHECK: llvm.mlir.constant(-1 : i32) : i32
1225+
// CHECK: [[MASKBLOCK:%.*]] = llvm.and
1226+
// CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1227+
// CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1228+
// CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
12091229
// CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1210-
// CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
1230+
// CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1
12111231
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
12121232
// CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32)
12131233
// CHECK-NEXT: ^bb1:
@@ -1295,22 +1315,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
12951315
// CHECK-NEXT: [[ARG0_1:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)>
12961316
// CHECK-NEXT: [[ARG1_0:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(f32, f32)>
12971317
// CHECK-NEXT: [[ARG1_1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(f32, f32)>
1298-
// CHECK: llvm.mlir.constant(true) : i1
12991318
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1300-
// CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO]]) {{.*}} : (i32) -> i64
1301-
// CHECK: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1
1302-
// CHECK: [[TRUE2:%.*]] = llvm.mlir.constant(true) : i1
1303-
// CHECK: [[PRED:%.*]] = llvm.and [[TRUE1]], [[TRUE2]] : i1
1319+
// CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1320+
// CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO1]]) {{.*}} : (i32) -> i64
1321+
// CHECK: [[PRED:%.*]] = llvm.mlir.constant(true) : i1
13041322
// CHECK: llvm.cond_br [[PRED]], ^bb1, ^bb2
13051323
// CHECK-NEXT: ^bb1:
13061324
// CHECK-NEXT: [[BCAST:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
13071325
// CHECK-NEXT: llvm.store {{.*}}, [[BCAST]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
13081326
// CHECK-NEXT: llvm.br ^bb2
13091327
// CHECK-NEXT: ^bb2:
1328+
// CHECK: llvm.mlir.undef : vector<1xf32>
1329+
// CHECK: [[PRED2:%.*]] = llvm.mlir.constant(true) : i1
13101330
// CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<1xi32>
13111331
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
13121332
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement {{.*}}, [[VEC]][[[ZERO]] : i32] : vector<1xi32>
1313-
// CHECK: llvm.cond_br [[PRED]], ^bb3, ^bb4
1333+
// CHECK: llvm.cond_br [[PRED2]], ^bb3, ^bb4
13141334
// CHECK-NEXT: ^bb3:
13151335
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
13161336
// CHECK-NEXT: llvm.store [[IE1]], [[BCAST1]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>

0 commit comments

Comments
 (0)