Skip to content

Commit 935ded3

Browse files
authored
Fix regression issue in flex decoding. (#3999)
The `tt.store` operation with BlockPointer fallbacks to scatter store if the BLOCK shape or the value layout was not supported by the 2D BLOCK IO. The lowering code would transform the BlockPointer to the pointers and masks. The scatter store should apply the `and` to the `maskElems` if the `llMask` doesn't exsits. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 9ef82e2 commit 935ded3

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,20 +364,31 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
364364
%c1_i64 = arith.constant 1 : i64
365365
%c0_i32 = arith.constant 0 : i32
366366
%0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x16xf16, #blocked>>
367+
// CHECK: llvm.call spir_funccc @_Z12get_local_idj
367368
// CHECK-NOT: llvm.icmp "slt"
368-
// CHECK-COUNT-32: llvm.store
369+
// CHECK: %[[threadID:.*]] = llvm.call spir_funccc @_Z12get_local_idj
370+
// CHECK: %[[VAL_583:.*]] = llvm.trunc %[[threadID]] : i64 to i32
371+
// CHECK: %[[VAL_584:.*]] = llvm.mlir.constant(16 : i32) : i32
372+
// CHECK: %[[VAL_586:.*]] = llvm.udiv %[[VAL_583]], %[[VAL_584]] : i32
373+
// CHECK: %[[VAL_587:.*]] = llvm.mlir.constant(3 : i32) : i32
374+
// CHECK: %[[VAL_588:.*]] = llvm.and %[[VAL_586]], %[[VAL_587]] : i32
375+
// CHECK: %[[threadPred:.*]] = llvm.icmp "eq" %[[VAL_588]], {{.*}} : i32
376+
// CHECK-COUNT-32: llvm.cond_br %[[threadPred]]
369377
tt.store %0, %cst : !tt.ptr<tensor<64x16xf16, #blocked>>
370378

371379
// CHECK-COUNT-16: llvm.icmp "slt"
372-
// CHECK-COUNT-32: llvm.store
380+
// CHECK: %[[threadPred_0:.*]] = llvm.icmp "eq"
381+
// CHECK-COUNT-32: llvm.and %[[threadPred_0]], {{.*}} : i1
373382
tt.store %0, %cst {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<64x16xf16, #blocked>>
374383

375384
// CHECK-COUNT-16: llvm.icmp "slt"
376-
// CHECK-COUNT-32: llvm.store
385+
// CHECK: %[[threadPred_1:.*]] = llvm.icmp "eq"
386+
// CHECK-COUNT-32: llvm.and %[[threadPred_1]], {{.*}} : i1
377387
tt.store %0, %cst {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x16xf16, #blocked>>
378388

379389
// CHECK-COUNT-32: llvm.icmp "slt"
380-
// CHECK-COUNT-32: llvm.store
390+
// CHECK: %[[threadPred_2:.*]] = llvm.icmp "eq"
391+
// CHECK-COUNT-32: llvm.and %[[threadPred_2]], {{.*}} : i1
381392
tt.store %0, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x16xf16, #blocked>>
382393

383394
tt.return

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2232,7 +2232,7 @@ struct StoreOpConversion
22322232
}
22332233

22342234
Value maskVal = threadPred;
2235-
if (llMask) {
2235+
if (maskElems.size() > 0) {
22362236
auto mask = maskElems[vecStart];
22372237
maskVal = maybeAnd(rewriter, loc, threadPred, mask);
22382238
}

0 commit comments

Comments
 (0)