Skip to content

Commit 8aba8cb

Browse files
[MatmulLoopPipeline] Predicate PrefetchOp (#4016)
Now that `PrefetchOp` takes `mask` as an argument, we can handle predication of `PrefetchOp` in `MatmulLoopPipeline`.

Benchmark CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/14657236822 (No performance regressions.)

---------

Signed-off-by: Whitney Tsang <[email protected]>
1 parent b768f79 commit 8aba8cb

File tree

3 files changed

+38
-33
lines changed

3 files changed

+38
-33
lines changed

test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,19 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
7979
%51 = arith.muli %arg7, %c32_i32 : i32
8080
%52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1>
8181
// COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop.
82+
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}}
83+
// CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
84+
// CHECK: triton_intel_gpu.prefetch {{.*}}, %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
85+
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
8286
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
83-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
84-
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
85-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
87+
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
8688
// CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>) : i32 {
89+
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}}
8790
// CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]>
8891
// CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]>
89-
// CHECK: triton_intel_gpu.prefetch %[[VAL_106]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
90-
// CHECK: triton_intel_gpu.prefetch %[[VAL_107]] {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
92+
// CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
93+
// CHECK: triton_intel_gpu.prefetch %[[VAL_106]], %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
94+
// CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
9195
// CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
9296
// CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
9397
// CHECK: %[[VAL_121:.*]] = ttg.convert_layout %[[VAL_116]] : tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>
@@ -166,12 +170,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
166170
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dot1>>
167171

168172
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
169-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
173+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
170174
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
171-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
175+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
172176
// CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
173177
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
174-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
178+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
175179
// CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
176180
// CHECK-NEXT: scf.yield
177181
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
@@ -239,12 +243,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
239243
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dot1>>
240244

241245
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
242-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
246+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
243247
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
244-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
248+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
245249
// CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
246250
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
247-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
251+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
248252
// CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
249253
// CHECK-NEXT: scf.yield
250254
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
@@ -302,18 +306,18 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
302306
%12 = arith.extsi %arg3 : i32 to i64
303307
// CHECK: scf.for %[[OUTER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (i32)
304308
// CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
305-
// CHECK-NEXT: triton_intel_gpu.prefetch [[PTR1]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
309+
// CHECK: triton_intel_gpu.prefetch [[PTR1]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
306310
// CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
307-
// CHECK-NEXT: triton_intel_gpu.prefetch [[PTR2]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
311+
// CHECK: triton_intel_gpu.prefetch [[PTR2]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
308312
// CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
309-
// CHECK-NEXT: triton_intel_gpu.prefetch [[PTR3]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
313+
// CHECK: triton_intel_gpu.prefetch [[PTR3]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
310314
// CHECK: [[PTR4:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
311-
// CHECK-NEXT: triton_intel_gpu.prefetch [[PTR4]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
315+
// CHECK: triton_intel_gpu.prefetch [[PTR4]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
312316
// CHECK-NEXT: scf.for %[[INNER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x128xf32, #blocked>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>)
313317
// CHECK: [[PTR5:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
314-
// CHECK-NEXT: triton_intel_gpu.prefetch [[PTR5]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
318+
// CHECK: triton_intel_gpu.prefetch [[PTR5]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
315319
// CHECK: [[PTR6:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
316-
// CHECK-NEXT: triton_intel_gpu.prefetch [[PTR6]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
320+
// CHECK: triton_intel_gpu.prefetch [[PTR6]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
317321
%13 = scf.for %arg6 = %0 to %7 step %c448_i32 iter_args(%arg7 = %10) -> (i32) : i32 {
318322
%14 = arith.divsi %arg6, %11 : i32
319323
%15 = arith.muli %14, %c8_i32 : i32

test/TritonIntelGPU/split-barrier.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
2626
// WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
2727
// SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Subgroup> <Subgroup> <None>
2828
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
29-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
29+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
3030
// CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
3131
// WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
3232
// SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>
@@ -73,7 +73,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
7373
// WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
7474
// SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Subgroup> <Subgroup> <None>
7575
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
76-
// CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
76+
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
7777
// CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
7878
// WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
7979
// SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Interfaces/SideEffectInterfaces.h"
77
#include "triton/Analysis/AxisInfo.h"
88
#include "triton/Dialect/Triton/IR/Dialect.h"
9+
#include "llvm/ADT/TypeSwitch.h"
910
#include "llvm/Support/Casting.h"
1011
#include "llvm/Support/Debug.h"
1112

@@ -149,12 +150,13 @@ static void collectOpsToPipeline(scf::ForOp forOp,
149150
}
150151
}
151152

152-
/// Combine the current mask with the given predicate.
153-
static Value getPredMask(RewriterBase &rewriter, Type typeLike,
154-
Value currentMask, Value pred) {
153+
/// Return a new mask shaped like \p typeLike, whose value combines the
154+
/// current mask \p currentMask with the given predicate \p pred.
155+
static Value computeNewMask(RewriterBase &rewriter, Type typeLike,
156+
Value currentMask, Value pred) {
155157
Location loc = pred.getLoc();
156158
Value mask = pred;
157-
Type maskType = tt::getI1SameShape(typeLike);
159+
Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike));
158160

159161
if (isa<RankedTensorType>(maskType))
160162
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
@@ -167,18 +169,17 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike,
167169
static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
168170
Value pred) {
169171
OpBuilder::InsertionGuard guard(rewriter);
170-
if (mlir::isMemoryEffectFree(op) || isa<ttgi::PrefetchOp>(op))
172+
if (mlir::isMemoryEffectFree(op))
171173
return op;
172174

173-
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
174-
rewriter.setInsertionPoint(loadOp);
175-
Value mask = getPredMask(rewriter, loadOp.getPtr().getType(),
176-
loadOp.getMask(), pred);
177-
loadOp.getMaskMutable().assign(mask);
178-
return loadOp;
179-
}
180-
181-
llvm_unreachable("don't know how to predicate this operation");
175+
return TypeSwitch<Operation *, Operation *>(op)
176+
.Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto op) {
177+
rewriter.setInsertionPoint(op);
178+
Value mask =
179+
computeNewMask(rewriter, op.getPtr().getType(), op.getMask(), pred);
180+
op.getMaskMutable().assign(mask);
181+
return op;
182+
});
182183
}
183184

184185
/// Helper to get the defining operation of a value.

0 commit comments

Comments
 (0)