
Commit 7a815f8

Limit prefetch to densely structured memory only
Signed-off-by: Whitney Tsang <[email protected]>
1 parent d9892a7 · commit 7a815f8
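In short: the loop pipeliner's prefetch now only considers loads whose memory access is known to be densely structured, signaled by the triton_intel_gpu.block_io attribute; everything else is skipped so prefetching cannot pollute the cache with scattered data. A minimal sketch of the read-side check, assuming the MLIR C++ API; the helper name isDenselyStructured is illustrative, not part of the repository:

#include "mlir/IR/Operation.h"
// Assumed header path for the Intel Triton GPU dialect declaration.
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"

// Hypothetical helper: returns true when a load carries the block_io
// attribute, i.e. its access pattern is densely structured (e.g. row-major).
static bool isDenselyStructured(mlir::Operation *loadOp) {
  mlir::Attribute blockIOAttr = loadOp->getAttr(
      mlir::triton::gpu::intel::TritonIntelGPUDialect::getBlockIOAttrName());
  return static_cast<bool>(blockIOAttr);
}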

File tree

3 files changed: +24 −10 lines


test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 6 additions & 6 deletions
@@ -100,11 +100,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
 %74 = tt.splat %73 : i32 -> tensor<1x32xi32, #blocked>
 %75 = arith.cmpi slt, %33, %74 : tensor<1x32xi32, #blocked>
 %76 = tt.broadcast %75 : tensor<1x32xi1, #blocked> -> tensor<64x32xi1, #blocked>
-%77 = tt.load %arg11, %76, %cst_0 : tensor<64x32x!tt.ptr<f16>, #blocked>
+%77 = tt.load %arg11, %76, %cst_0 {triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #blocked>
 %78 = tt.splat %73 : i32 -> tensor<32x1xi32, #blocked1>
 %79 = arith.cmpi slt, %40, %78 : tensor<32x1xi32, #blocked1>
 %80 = tt.broadcast %79 : tensor<32x1xi1, #blocked1> -> tensor<32x256xi1, #blocked1>
-%81 = tt.load %arg12, %80, %cst_1 : tensor<32x256x!tt.ptr<f16>, #blocked1>
+%81 = tt.load %arg12, %80, %cst_1 {triton_intel_gpu.block_io = "row_major"} : tensor<32x256x!tt.ptr<f16>, #blocked1>
 %82 = ttg.convert_layout %77 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
 %83 = ttg.convert_layout %81 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
 %84 = tt.dot %82, %83, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -175,8 +175,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
 // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
 // CHECK-NEXT: scf.yield
 %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
-%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
-%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dot1>>
+%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #dot0>>
+%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x256xf16, #dot1>>
 %58 = tt.dot %56, %57, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<128x256xf32, #dpas>
 %59 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>>
 %60 = tt.advance %arg12, [%c64_i32, %c0_i32] : <tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
@@ -248,8 +248,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
 // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
 // CHECK-NEXT: scf.yield
 %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
-%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
-%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dot1>>
+%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #dot0>>
+%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x256xf16, #dot1>>
 %58 = tt.dot %56, %57, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<128x256xf32, #dpas>
 %102 = tt.addptr %arg8, %c4_i32 : !tt.ptr<i32>, i32
 %100 = arith.addi %c0_i32, %c4_i32 : i32

test/TritonIntelGPU/split-barrier.mlir

Lines changed: 4 additions & 4 deletions
@@ -33,8 +33,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
 // CHECK-NEXT: scf.yield
 %23:3 = scf.for %arg2 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg3 = %cst, %arg4 = %18, %arg5 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
 %55:3 = scf.for %arg9 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
-%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
-%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dot1>>
+%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #dot0>>
+%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x256xf16, #dot1>>
 %58 = tt.dot %56, %57, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<128x256xf32, #dpas>
 %59 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>>
 %60 = tt.advance %arg12, [%c64_i32, %c0_i32] : <tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
@@ -79,8 +79,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
 // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>
 // CHECK-NEXT: scf.yield
 %23:3 = scf.for %arg9 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
-%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
-%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dot1>>
+%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #dot0>>
+%57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x256xf16, #dot1>>
 %58 = tt.dot %56, %57, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<128x256xf32, #dpas>
 %59 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>>
 %60 = tt.advance %arg12, [%c64_i32, %c0_i32] : <tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 14 additions & 0 deletions
@@ -132,6 +132,20 @@ static void collectOpsToPipeline(scf::ForOp forOp,
     if (!isBlockPtr && !supportRegularPtr)
       continue;
 
+    // Check whether the memory access is densely structured. If it is not,
+    // do not prefetch it, to avoid polluting the cache.
+    Attribute blockIOAttr =
+        loadOp->getAttr(mlir::triton::gpu::intel::TritonIntelGPUDialect::
+                            getBlockIOAttrName());
+    if (!blockIOAttr) {
+      LLVM_DEBUG({
+        DBGS() << "skip the Load op to pipeline because the memory is not "
+                  "structured densely: ";
+        DBGS() << " " << loadOp << "\n";
+      });
+      continue;
+    }
+
     std::optional<LoadDotOperand> loadWithDotOperand = loadDotOperand(loadOp);
     if (loadWithDotOperand.has_value())
       loadOps.push_back(loadWithDotOperand.value());
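The test updates above mirror the same contract from the producer side: a load stays a prefetch candidate only if it was tagged with triton_intel_gpu.block_io = "row_major". A hedged sketch of attaching that tag programmatically; markRowMajorBlockIO and the include path are assumptions for illustration, not the repository's API:

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
// Assumed header path for the Intel Triton GPU dialect declaration.
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"

// Hypothetical helper: tags a tt.load as a row-major block access so that
// collectOpsToPipeline keeps it as a prefetch candidate.
static void markRowMajorBlockIO(mlir::Operation *loadOp) {
  mlir::Builder b(loadOp->getContext());
  loadOp->setAttr(
      mlir::triton::gpu::intel::TritonIntelGPUDialect::getBlockIOAttrName(),
      b.getStringAttr("row_major"));
}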
