
Commit f6b90ca

Support tensors of pointers in the matmul loop pipelining.
1 parent da874b8 commit f6b90ca

3 files changed: +50 −20 lines


test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 21 additions & 14 deletions
@@ -79,32 +79,39 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
     %51 = arith.muli %arg7, %c32_i32 : i32
     %52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1>
     // COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop.
-    // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
-    // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
-    // CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>) : i32 {
+    // CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} : tensor<1x32xi32, #[[$BLOCK_0]]>
+    // CHECK: %[[LOAD_MASK_2D:.*]] = tt.broadcast %[[LOAD_MASK]] : tensor<1x32xi1, #[[$BLOCK_0]]> -> tensor<64x32xi1, #[[$BLOCK_0]]>
+    // CHECK: %[[LOOP_MASK:.*]] = tt.splat {{.*}} : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
+    // CHECK: %[[PREFETCH_MASK:.*]] = arith.andi %[[LOOP_MASK]], %[[LOAD_MASK_2D]] : tensor<64x32xi1, #[[$BLOCK_0]]>
+    // CHECK: triton_intel_gpu.prefetch {{.*}}, %[[PREFETCH_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>
+    // CHECK: %[[LOAD_MASK_2:.*]] = arith.cmpi slt, {{.*}} : tensor<32x1xi32, #[[$BLOCK_1]]>
+    // CHECK: %[[LOAD_MASK_2D_2:.*]] = tt.broadcast %[[LOAD_MASK_2]] : tensor<32x1xi1, #[[$BLOCK_1]]> -> tensor<32x256xi1, #[[$BLOCK_1]]>
+    // CHECK: %[[LOOP_MASK:.*]] = tt.splat {{.*}} : i1 -> tensor<32x256xi1, #[[$BLOCK_1]]>
+    // CHECK: %[[PREFETCH_MASK:.*]] = arith.andi %[[LOOP_MASK]], %[[LOAD_MASK_2D_2]] : tensor<32x256xi1, #[[$BLOCK_1]]>
+    // CHECK: triton_intel_gpu.prefetch {{.*}}, %[[PREFETCH_MASK]] {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>
+    // CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}, %[[VAL_98:.*]] = {{.*}}, %[[VAL_99:.*]] = {{.*}}, %[[VAL_100:.*]] = {{.*}}, %[[VAL_101:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>) : i32 {
     // CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]>
     // CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]>
-    // CHECK: triton_intel_gpu.prefetch %[[VAL_106]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK: triton_intel_gpu.prefetch %[[VAL_107]] {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
-    // CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
-    // CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
+    // CHECK: triton_intel_gpu.prefetch %[[VAL_106]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>
+    // CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>
+    // CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} {triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
+    // CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_99]], {{.*}}, {{.*}} {triton_intel_gpu.block_io = "row_major"} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
     // CHECK: %[[VAL_121:.*]] = ttg.convert_layout %[[VAL_116]] : tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>
     // CHECK: %[[VAL_122:.*]] = ttg.convert_layout %[[VAL_120]] : tensor<32x256xf16, #[[$BLOCK_1]]> -> tensor<32x256xf16, #{{.*}}<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
     // CHECK: %[[VAL_123:.*]] = tt.dot %[[VAL_121]], %[[VAL_122]], %[[VAL_93]], inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #{{.*}}<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[$DPAS]]>
-    // CHECK: scf.yield %[[VAL_123]], %[[VAL_106]], %[[VAL_107]], %[[VAL_94]], %[[VAL_95]] : tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
     %53:3 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %38, %arg12 = %48) -> (tensor<64x256xf32, #dpas>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x256x!tt.ptr<f16>, #blocked1>) : i32 {
       %72 = arith.muli %arg9, %c32_i32 : i32
       %73 = arith.subi %arg5, %72 : i32
       %74 = tt.splat %73 : i32 -> tensor<1x32xi32, #blocked>
       %75 = arith.cmpi slt, %33, %74 : tensor<1x32xi32, #blocked>
       %76 = tt.broadcast %75 : tensor<1x32xi1, #blocked> -> tensor<64x32xi1, #blocked>
-      %77 = tt.load %arg11, %76, %cst_0 : tensor<64x32x!tt.ptr<f16>, #blocked>
+      %77 = tt.load %arg11, %76, %cst_0 {triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #blocked>
       %78 = tt.splat %73 : i32 -> tensor<32x1xi32, #blocked1>
       %79 = arith.cmpi slt, %40, %78 : tensor<32x1xi32, #blocked1>
       %80 = tt.broadcast %79 : tensor<32x1xi1, #blocked1> -> tensor<32x256xi1, #blocked1>
-      %81 = tt.load %arg12, %80, %cst_1 : tensor<32x256x!tt.ptr<f16>, #blocked1>
+      %81 = tt.load %arg12, %80, %cst_1 {triton_intel_gpu.block_io = "row_major"} : tensor<32x256x!tt.ptr<f16>, #blocked1>
       %82 = ttg.convert_layout %77 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
       %83 = ttg.convert_layout %81 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
       %84 = tt.dot %82, %83, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -175,8 +182,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
     // CHECK-NEXT: scf.yield
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
-      %56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
-      %57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dot1>>
+      %56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #dot0>>
+      %57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x256xf16, #dot1>>
       %58 = tt.dot %56, %57, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<128x256xf32, #dpas>
       %59 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>>
       %60 = tt.advance %arg12, [%c64_i32, %c0_i32] : <tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
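
The rewritten CHECK lines pin down how a prefetch of a tensor of pointers must now be masked: the load's bounds mask (arith.cmpi plus tt.broadcast) is ANDed with a tt.splat of the scalar loop guard, and the combined i1 tensor travels with the pointer into triton_intel_gpu.prefetch. Below is a minimal C++ builder sketch of that pattern; the names builder, loc, loadMask, loopPred, and loadOp are assumptions for illustration, and in the pass itself the guard is folded in later by predicateOp/getPredMask rather than at prefetch-creation time.

// Sketch only: the mask combination the CHECK lines above verify.
// `loadMask` is the load's existing bounds mask (e.g. tensor<64x32xi1>),
// `loopPred` is the scalar i1 guard supplied by the pipeliner.
auto maskTy = cast<RankedTensorType>(loadMask.getType());
Value loopMask = builder.create<tt::SplatOp>(loc, maskTy, loopPred);
Value prefetchMask = builder.create<arith::AndIOp>(loc, loopMask, loadMask);
builder.create<ttgi::PrefetchOp>(loc, loadOp.getPtr(), prefetchMask,
                                 loadOp.getCache(), loadOp.getEvict(),
                                 loadOp.getIsVolatile());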

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 1 deletion
@@ -280,7 +280,8 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_accelerate_matmul(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
         intel.passes.ttgpuir.add_materialize_block_pointer(pm)
-        intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
+        intel.passes.ttgpuir.add_remove_layout_conversions(pm)
+        intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, True)
 
         passes.ttgpuir.add_fuse_nested_loops(pm)
         passes.ttgpuir.add_optimize_thread_locality(pm)

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 27 additions & 5 deletions
@@ -84,12 +84,12 @@ static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) {
 }
 
 /// Create a prefetch operation for the given load operation.
-static void createPrefetchOp(scf::ForOp &forOp, tt::LoadOp loadOp, Value ptr) {
+static void createPrefetchOp(scf::ForOp &forOp, tt::LoadOp loadOp) {
   OpBuilder builder(forOp);
   builder.setInsertionPoint(loadOp);
   auto prefetchOp = builder.create<ttgi::PrefetchOp>(
-      loadOp->getLoc(), ptr, loadOp.getCache(), loadOp.getEvict(),
-      loadOp.getIsVolatile());
+      loadOp->getLoc(), loadOp.getPtr(), loadOp.getMask(), loadOp.getCache(),
+      loadOp.getEvict(), loadOp.getIsVolatile());
 
   // inherit attributes from the load operation
   auto attrs = loadOp->getAttrDictionary();
@@ -102,7 +102,7 @@ static void createPrefetchOps(scf::ForOp &forOp,
   assert(!loads.empty() && "Expecting at least one load operation");
   for (const LoadDotOperand &loadOperand : loads) {
     tt::LoadOp loadOp = loadOperand.load;
-    createPrefetchOp(forOp, loadOp, loadOp.getPtr());
+    createPrefetchOp(forOp, loadOp);
   }
 }
 
@@ -132,6 +132,20 @@ static void collectOpsToPipeline(scf::ForOp forOp,
     if (!isBlockPtr && !supportRegularPtr)
       continue;
 
+    // Check if the memory is structured densely. If not, we do not prefetch
+    // it, to avoid polluting the cache.
+    Attribute blockIOAttr =
+        loadOp->getAttr(mlir::triton::gpu::intel::TritonIntelGPUDialect::
+                            getBlockIOAttrName());
+    if (!blockIOAttr) {
+      LLVM_DEBUG({
+        DBGS() << "skip the Load op to pipeline because the memory is not "
+                  "structured densely: ";
+        DBGS() << " " << loadOp << "\n";
+      });
+      continue;
+    }
+
     std::optional<LoadDotOperand> loadWithDotOperand = loadDotOperand(loadOp);
     if (loadWithDotOperand.has_value())
       loadOps.push_back(loadWithDotOperand.value());
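
The new gate keys off the triton_intel_gpu.block_io attribute, presumably attached by the materialize_block_pointer pass that compiler.py now runs shortly before the pipeliner; only loads whose access pattern is known to be densely structured are worth prefetching. A hedged sketch of a stricter variant that also inspects the attribute's value: "row_major" comes from the test above, "column_major" is an assumed companion value, and the committed code only tests for the attribute's presence.

// Illustrative variant, not the committed code: also check the value.
// "column_major" is an assumption; "row_major" appears in the test.
auto blockIO = loadOp->getAttrOfType<StringAttr>(
    mlir::triton::gpu::intel::TritonIntelGPUDialect::getBlockIOAttrName());
if (!blockIO || (blockIO.getValue() != "row_major" &&
                 blockIO.getValue() != "column_major"))
  continue; // unknown or non-dense access pattern: do not prefetch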
@@ -157,7 +171,7 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike,
 static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
                               Value pred) {
   OpBuilder::InsertionGuard guard(rewriter);
-  if (mlir::isMemoryEffectFree(op) || isa<ttgi::PrefetchOp>(op))
+  if (mlir::isMemoryEffectFree(op))
     return op;
 
   if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
@@ -168,6 +182,14 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
     return loadOp;
   }
 
+  if (auto prefetchOp = dyn_cast<ttgi::PrefetchOp>(op)) {
+    rewriter.setInsertionPoint(prefetchOp);
+    Value mask = getPredMask(rewriter, prefetchOp.getPtr().getType(),
+                             prefetchOp.getMask(), pred);
+    prefetchOp.getMaskMutable().assign(mask);
+    return prefetchOp;
+  }
+
   llvm_unreachable("don't know how to predicate this operation");
 }