Skip to content

Commit cc4150d

Browse files
address review comments
Signed-off-by: Whitney Tsang <[email protected]>
1 parent ed31253 commit cc4150d

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
8181
// COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop.
8282
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}}
8383
// CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
84-
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
84+
// CHECK: triton_intel_gpu.prefetch {{.*}}, %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
8585
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
8686
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
8787
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,10 @@ static void collectOpsToPipeline(scf::ForOp forOp,
150150
}
151151
}
152152

153-
/// Combine the current mask with the given predicate.
154-
static Value getPredMask(RewriterBase &rewriter, Type typeLike,
155-
Value currentMask, Value pred) {
153+
/// Return a new mask of type of shape \p typeLike, and value combining the
154+
/// current mask \p currentMask with the given predicate \p pred.
155+
static Value computeNewMask(RewriterBase &rewriter, Type typeLike,
156+
Value currentMask, Value pred) {
156157
Location loc = pred.getLoc();
157158
Value mask = pred;
158159
Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike));
@@ -175,7 +176,7 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
175176
.Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto op) {
176177
rewriter.setInsertionPoint(op);
177178
Value mask =
178-
getPredMask(rewriter, op.getPtr().getType(), op.getMask(), pred);
179+
computeNewMask(rewriter, op.getPtr().getType(), op.getMask(), pred);
179180
op.getMaskMutable().assign(mask);
180181
return op;
181182
});

0 commit comments

Comments
 (0)