@@ -676,6 +676,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
676
676
// CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
677
677
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[BCAST0]], [[VEC1]][[[CST_0]] : i32] : vector<1xf32>
678
678
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to i32
679
+ // CHECK-NEXT: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1
679
680
// CHECK-NEXT: [[AND1:%.*]] = llvm.and {{.*}}, [[ARG2_0]] : i1
680
681
// CHECK-NEXT: [[VEC2:%.*]] = llvm.mlir.undef : vector<1xi32>
681
682
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -1059,17 +1060,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1059
1060
module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1060
1061
// CHECK-LABEL: atomic_cas_f32_scalar_no_store
1061
1062
tt.func @atomic_cas_f32_scalar_no_store(%ptr : !tt.ptr<f32>, %cmp : f32, %val : f32) {
1062
- // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
1063
- // CHECK: [[CMP0:%.*]] = llvm.icmp "eq"
1064
- // CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]]
1065
- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1066
- // CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
1067
- // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1063
+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1064
+ // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1065
+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1066
+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1067
+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1068
+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1069
+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1070
+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1071
+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1072
+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1073
+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1074
+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1068
1075
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1069
1076
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1070
1077
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
1071
1078
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
1072
- // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
1079
+ // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32)
1073
1080
// CHECK-NEXT: ^bb1:
1074
1081
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
1075
1082
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32
@@ -1089,13 +1096,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
1089
1096
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
1090
1097
// CHECK-LABEL: atomic_cas_f32_scalar
1091
1098
tt.func @atomic_cas_f32_scalar(%ptr : !tt.ptr<f32>, %cmp : f32, %val : f32) {
1092
- // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
1093
- // CHECK: [[CMP0:%.*]] = llvm.icmp "eq"
1094
- // CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]]
1095
- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1096
- // CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
1097
- // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1098
- // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
1099
+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1100
+ // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1101
+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1102
+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1103
+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1104
+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1105
+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1106
+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1107
+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1108
+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1109
+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1110
+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1111
+ // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32)
1099
1112
// CHECK-NEXT: ^bb1:
1100
1113
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
1101
1114
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32
@@ -1131,14 +1144,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1131
1144
// CHECK-NEXT: [[EV1_ARG2:%.*]] = llvm.extractvalue %arg2[1] : !llvm.struct<(f32, f32)>
1132
1145
// CHECK: [[EV0_ARG0:%.*]] = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr<1>, ptr<1>)>
1133
1146
// CHECK-NEXT: [[EV1_ARG0:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)>
1134
- // CHECK: llvm.mlir.constant(true) : i1
1135
- // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1136
- // CHECK: [[PRED0:%.*]] = llvm.and [[CST_TRUE]], {{.*}} : i1
1137
- // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1147
+ // CHECK: [[EV0_ARG1:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(i1, i1)>
1148
+ // CHECK-NEXT: [[EV1_ARG1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(i1, i1)>
1149
+ // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1138
1150
// CHECK: [[IE1:%.*]] = llvm.insertelement [[EV0_ARG2]], [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1139
- // CHECK-NEXT: [[PRED1:%.*]] = llvm.and [[PRED0]], {{.*}} : i1
1140
1151
// CHECK-NEXT: [[ZERO1:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1141
- // CHECK: llvm.cond_br [[PRED1]], ^bb1, ^bb2([[ZERO1]] : f32)
1152
+ // CHECK: llvm.cond_br [[EV0_ARG1]], ^bb1, ^bb2([[ZERO1]] : f32)
1142
1153
// CHECK-NEXT: ^bb1:
1143
1154
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32
1144
1155
// CHECK-NEXT: [[RMW_RES1:%.*]] = llvm.atomicrmw fadd [[EV0_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
@@ -1147,13 +1158,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1147
1158
// CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI1]] : f32 to f32
1148
1159
// CHECK-NEXT: [[UNDEF2:%.*]] = llvm.mlir.undef : vector<1xf32>
1149
1160
// CHECK: [[IE2:%.*]] = llvm.insertelement [[EV1_ARG2]], [[UNDEF2]][{{.*}} : i64] : vector<1xf32>
1150
- // CHECK-NEXT: [[PRED2:%.*]] = llvm.and [[PRED0]], {{.*}} : i1
1151
1161
// CHECK-NEXT: [[ZERO2:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1152
1162
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1153
1163
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1154
1164
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
1155
1165
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
1156
- // CHECK-NEXT: llvm.cond_br [[PRED2]], ^bb3, ^bb4([[ZERO2]] : f32)
1166
+ // CHECK-NEXT: llvm.cond_br [[EV1_ARG1]], ^bb3, ^bb4([[ZERO2]] : f32)
1157
1167
// CHECK-NEXT: ^bb3:
1158
1168
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE2]] : vector<1xf32> to f32
1159
1169
// CHECK-NEXT: [[RMW_RES2:%.*]] = llvm.atomicrmw fadd [[EV1_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
@@ -1169,14 +1179,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1169
1179
module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1170
1180
// CHECK-LABEL: atomic_add_f32_scalar_no_store
1171
1181
tt.func @atomic_add_f32_scalar_no_store(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
1172
- // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1173
- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1174
- // CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1
1175
- // CHECK: [[CMP1:%.*]] = llvm.icmp "eq"
1176
- // CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1
1177
- // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1182
+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1183
+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1184
+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1185
+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1186
+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1187
+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1188
+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1189
+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1190
+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1191
+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1192
+ // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1178
1193
// CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1179
- // CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
1194
+ // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1
1180
1195
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1181
1196
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1182
1197
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
@@ -1200,14 +1215,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
1200
1215
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
1201
1216
// CHECK-LABEL: atomic_add_f32_scalar
1202
1217
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
1203
- // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1204
- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1205
- // CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1
1206
- // CHECK: [[CMP1:%.*]] = llvm.icmp "eq"
1207
- // CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1
1208
- // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1218
+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1219
+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1220
+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1221
+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1222
+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1223
+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1224
+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1225
+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1226
+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1227
+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1228
+ // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1209
1229
// CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1210
- // CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
1230
+ // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1
1211
1231
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1212
1232
// CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32)
1213
1233
// CHECK-NEXT: ^bb1:
@@ -1295,22 +1315,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1295
1315
// CHECK-NEXT: [[ARG0_1:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)>
1296
1316
// CHECK-NEXT: [[ARG1_0:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(f32, f32)>
1297
1317
// CHECK-NEXT: [[ARG1_1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(f32, f32)>
1298
- // CHECK: llvm.mlir.constant(true) : i1
1299
1318
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1300
- // CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO]]) {{.*}} : (i32) -> i64
1301
- // CHECK: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1
1302
- // CHECK: [[TRUE2:%.*]] = llvm.mlir.constant(true) : i1
1303
- // CHECK: [[PRED:%.*]] = llvm.and [[TRUE1]], [[TRUE2]] : i1
1319
+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1320
+ // CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO1]]) {{.*}} : (i32) -> i64
1321
+ // CHECK: [[PRED:%.*]] = llvm.mlir.constant(true) : i1
1304
1322
// CHECK: llvm.cond_br [[PRED]], ^bb1, ^bb2
1305
1323
// CHECK-NEXT: ^bb1:
1306
1324
// CHECK-NEXT: [[BCAST:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
1307
1325
// CHECK-NEXT: llvm.store {{.*}}, [[BCAST]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
1308
1326
// CHECK-NEXT: llvm.br ^bb2
1309
1327
// CHECK-NEXT: ^bb2:
1328
+ // CHECK: llvm.mlir.undef : vector<1xf32>
1329
+ // CHECK: [[PRED2:%.*]] = llvm.mlir.constant(true) : i1
1310
1330
// CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<1xi32>
1311
1331
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1312
1332
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement {{.*}}, [[VEC]][[[ZERO]] : i32] : vector<1xi32>
1313
- // CHECK: llvm.cond_br [[PRED]], ^bb3, ^bb4
1333
+ // CHECK: llvm.cond_br [[PRED2]], ^bb3, ^bb4
1314
1334
// CHECK-NEXT: ^bb3:
1315
1335
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
1316
1336
// CHECK-NEXT: llvm.store [[IE1]], [[BCAST1]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
0 commit comments