
Commit 6c7a802

[MatmulLoopPipeline] Predicate PrefetchOp
Signed-off-by: Whitney Tsang <[email protected]>
Parent: 21e483c
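Summary: this commit makes the Intel matmul loop pipeliner predicate triton_intel_gpu.prefetch ops the same way it already predicates tt.load. Instead of skipping prefetches as if they were memory-effect free, predicateOp now dispatches over both op kinds with llvm::TypeSwitch and assigns each a mask derived from the loop predicate; getPredMask correspondingly builds the mask type from the pointee type, so both tensor-of-pointer and block-pointer (!tt.ptr<tensor<...>>) operands are handled. The tests change in two ways: prefetch ops are now checked with a mask operand after the pointer (prefetch %ptr, %mask), and CHECK-NEXT is relaxed to CHECK where the mask computation may now appear between consecutive prefetches.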

3 files changed: +31 / -29 lines changed


test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 18 additions & 16 deletions
@@ -79,15 +79,17 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
     %51 = arith.muli %arg7, %c32_i32 : i32
     %52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1>
     // COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop.
+    // CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}}
+    // CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
     // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
     // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
     // CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>) : i32 {
     // CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]>
     // CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]>
-    // CHECK: triton_intel_gpu.prefetch %[[VAL_106]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK: triton_intel_gpu.prefetch %[[VAL_107]] {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
+    // CHECK: triton_intel_gpu.prefetch %[[VAL_106]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
+    // CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
     // CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
     // CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
     // CHECK: %[[VAL_121:.*]] = ttg.convert_layout %[[VAL_116]] : tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
@@ -166,12 +168,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dot1>>

     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
     // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
     // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
     // CHECK-NEXT: scf.yield
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
@@ -239,12 +241,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dot1>>

     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
     // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
     // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
     // CHECK-NEXT: scf.yield
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
@@ -302,18 +304,18 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
     %12 = arith.extsi %arg3 : i32 to i64
     // CHECK: scf.for %[[OUTER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (i32)
     // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR1]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: triton_intel_gpu.prefetch [[PTR1]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
     // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR2]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: triton_intel_gpu.prefetch [[PTR2]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
     // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR3]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: triton_intel_gpu.prefetch [[PTR3]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
     // CHECK: [[PTR4:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR4]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: triton_intel_gpu.prefetch [[PTR4]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
     // CHECK-NEXT: scf.for %[[INNER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x128xf32, #blocked>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>)
     // CHECK: [[PTR5:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR5]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: triton_intel_gpu.prefetch [[PTR5]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
     // CHECK: [[PTR6:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR6]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: triton_intel_gpu.prefetch [[PTR6]], {{.*}} {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
     %13 = scf.for %arg6 = %0 to %7 step %c448_i32 iter_args(%arg7 = %10) -> (i32) : i32 {
       %14 = arith.divsi %arg6, %11 : i32
       %15 = arith.muli %14, %c8_i32 : i32
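The LOAD_MASK/LOOP_MASK checks above show the shape of the predication the pass emits: a scalar i1 bound check that is splat to an i1 tensor matching the prefetched tile. As a rough illustration of that splat step (not the pass's actual helper; buildTensorMask and its signature are hypothetical, and the real logic lives in getPredMask in MatmulLoopPipeline.cpp below):

    // Hypothetical sketch, assuming mlir/IR/PatternMatch.h and
    // triton/Dialect/Triton/IR/Dialect.h are available: broadcast a scalar i1
    // predicate to an i1 tensor with the tile's shape and layout encoding, as
    // in the "tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, ...>" check.
    static mlir::Value buildTensorMask(mlir::RewriterBase &rewriter,
                                       mlir::Location loc, mlir::Value pred,
                                       mlir::RankedTensorType tileTy) {
      auto maskTy = mlir::RankedTensorType::get(
          tileTy.getShape(), rewriter.getI1Type(), tileTy.getEncoding());
      return rewriter.create<mlir::triton::SplatOp>(loc, maskTy, pred);
    }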

test/TritonIntelGPU/split-barrier.mlir

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     // WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
     // SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Subgroup> <Subgroup> <None>
     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
     // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
     // WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
     // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>
@@ -73,7 +73,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     // WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
     // SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Subgroup> <Subgroup> <None>
     // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
     // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
     // WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
     // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 11 additions & 11 deletions
@@ -6,6 +6,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"

@@ -154,7 +155,7 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike,
                          Value currentMask, Value pred) {
   Location loc = pred.getLoc();
   Value mask = pred;
-  Type maskType = tt::getI1SameShape(typeLike);
+  Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike));

   if (isa<RankedTensorType>(maskType))
     mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
@@ -167,18 +168,17 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
 static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
                               Value pred) {
   OpBuilder::InsertionGuard guard(rewriter);
-  if (mlir::isMemoryEffectFree(op) || isa<ttgi::PrefetchOp>(op))
+  if (mlir::isMemoryEffectFree(op))
     return op;

-  if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
-    rewriter.setInsertionPoint(loadOp);
-    Value mask = getPredMask(rewriter, loadOp.getPtr().getType(),
-                             loadOp.getMask(), pred);
-    loadOp.getMaskMutable().assign(mask);
-    return loadOp;
-  }
-
-  llvm_unreachable("don't know how to predicate this operation");
+  return TypeSwitch<Operation *, Operation *>(op)
+      .Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto op) {
+        rewriter.setInsertionPoint(op);
+        Value mask =
+            getPredMask(rewriter, op.getPtr().getType(), op.getMask(), pred);
+        op.getMaskMutable().assign(mask);
+        return op;
+      });
 }

 /// Helper to get the defining operation of a value.
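Two details of the new predicateOp are worth noting. First, getPredMask now applies tt::getI1SameShape to tt::getPointeeType(typeLike) rather than to typeLike itself, so a block-pointer operand (!tt.ptr<tensor<...>>) also yields a usable mask type. Second, the llvm::TypeSwitch replaces the old dyn_cast chain because tt::LoadOp and ttgi::PrefetchOp expose the same getPtr()/getMask()/getMaskMutable() accessors, letting one generic lambda serve both. A minimal sketch of the idiom follows; predicateOpSketch is a hypothetical name, and the .Default arm is illustrative only (the committed code omits it and relies on TypeSwitch's internal assertion for unhandled op kinds):

    #include "llvm/ADT/TypeSwitch.h"
    #include "llvm/Support/ErrorHandling.h"

    // Sketch, assuming the file's tt/ttgi namespace aliases and getPredMask.
    static mlir::Operation *predicateOpSketch(mlir::RewriterBase &rewriter,
                                              mlir::Operation *op,
                                              mlir::Value pred) {
      return llvm::TypeSwitch<mlir::Operation *, mlir::Operation *>(op)
          .Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto maskedOp) {
            // `auto` gives one lambda per concrete op type; both types
            // provide the getPtr()/getMask()/getMaskMutable() accessors.
            rewriter.setInsertionPoint(maskedOp);
            mlir::Value mask = getPredMask(
                rewriter, maskedOp.getPtr().getType(), maskedOp.getMask(), pred);
            maskedOp.getMaskMutable().assign(mask);
            return maskedOp;
          })
          .Default([](mlir::Operation *) -> mlir::Operation * {
            // Illustrative: mirrors the diagnostic the removed
            // llvm_unreachable used to provide.
            llvm_unreachable("don't know how to predicate this operation");
          });
    }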
