Skip to content

Commit 40bb2fa

Browse files
chengjunlu and whitneywhtsang
authored and committed
Support the tensor of pointer in the matmul loop pipelining.
1 parent 765fe50 commit 40bb2fa

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,28 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
7979
%51 = arith.muli %arg7, %c32_i32 : i32
8080
%52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1>
8181
// COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop.
82-
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}}
83-
// CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
84-
// CHECK: triton_intel_gpu.prefetch {{.*}}, %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
85-
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
86-
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
87-
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
88-
// CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>) : i32 {
89-
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}}
82+
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} : tensor<1x32xi32, #[[$BLOCK_0]]>
83+
// CHECK: %[[LOAD_MASK_2D:.*]] = tt.broadcast %[[LOAD_MASK]] : tensor<1x32xi1, #[[$BLOCK_0]]> -> tensor<64x32xi1, #[[$BLOCK_0]]>
84+
// CHECK: %[[LOOP_MASK:.*]] = tt.splat {{.*}} : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
85+
// CHECK: %[[PREFETCH_MASK:.*]] = arith.andi %[[LOOP_MASK]], %[[LOAD_MASK_2D]] : tensor<64x32xi1, #[[$BLOCK_0]]>
86+
// CHECK: triton_intel_gpu.prefetch {{.*}}, %[[PREFETCH_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
87+
// CHECK: %[[LOAD_MASK_2:.*]] = arith.cmpi slt, {{.*}} : tensor<32x1xi32, #[[$BLOCK_1]]>
88+
// CHECK: %[[LOAD_MASK_2D_2:.*]] = tt.broadcast %[[LOAD_MASK_2]] : tensor<32x1xi1, #[[$BLOCK_1]]> -> tensor<32x256xi1, #[[$BLOCK_1]]>
89+
// CHECK: %[[LOOP_MASK:.*]] = tt.splat {{.*}} : i1 -> tensor<32x256xi1, #[[$BLOCK_1]]>
90+
// CHECK: %[[PREFETCH_MASK:.*]] = arith.andi %[[LOOP_MASK]], %[[LOAD_MASK_2D_2]] : tensor<32x256xi1, #[[$BLOCK_1]]>
91+
// CHECK: triton_intel_gpu.prefetch {{.*}}, %[[PREFETCH_MASK]] {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
92+
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
93+
// CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
94+
// CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}, %[[VAL_98:.*]] = {{.*}}, %[[VAL_99:.*]] = {{.*}}, %[[VAL_100:.*]] = {{.*}}, %[[VAL_101:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>, tensor<64x32xi1, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>, tensor<32x256xi1, #[[$BLOCK_1]]>) : i32 {
9095
// CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]>
9196
// CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]>
92-
// CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]>
93-
// CHECK: triton_intel_gpu.prefetch %[[VAL_106]], %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
97+
// CHECK: triton_intel_gpu.prefetch %[[VAL_106]], {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
9498
// CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
95-
// CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
96-
// CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
99+
// CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} {triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>
100+
// CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_99]], {{.*}}, {{.*}} {triton_intel_gpu.block_io = "row_major"} : tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
97101
// CHECK: %[[VAL_121:.*]] = ttg.convert_layout %[[VAL_116]] : tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>
98102
// CHECK: %[[VAL_122:.*]] = ttg.convert_layout %[[VAL_120]] : tensor<32x256xf16, #[[$BLOCK_1]]> -> tensor<32x256xf16, #{{.*}}<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
99103
// CHECK: %[[VAL_123:.*]] = tt.dot %[[VAL_121]], %[[VAL_122]], %[[VAL_93]], inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #{{.*}}<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[$DPAS]]>
100-
// CHECK: scf.yield %[[VAL_123]], %[[VAL_106]], %[[VAL_107]], %[[VAL_94]], %[[VAL_95]] : tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr<f16>, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr<f16>, #[[$BLOCK_1]]>
101104
%53:3 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %38, %arg12 = %48) -> (tensor<64x256xf32, #dpas>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x256x!tt.ptr<f16>, #blocked1>) : i32 {
102105
%72 = arith.muli %arg9, %c32_i32 : i32
103106
%73 = arith.subi %arg5, %72 : i32

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,8 @@ def make_ttgir(mod, metadata, opt, properties):
310310
intel.passes.ttgpuir.add_accelerate_matmul(pm)
311311
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
312312
intel.passes.ttgpuir.add_materialize_block_pointer(pm)
313-
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False, XPUBackend.get_split_barrier_scope(opt))
313+
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
314+
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, True, XPUBackend.get_split_barrier_scope(opt))
314315

315316
passes.ttgpuir.add_fuse_nested_loops(pm)
316317
passes.ttgpuir.add_optimize_thread_locality(pm)

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ static void createPrefetchOp(scf::ForOp &forOp, tt::LoadOp loadOp) {
8989
OpBuilder builder(forOp);
9090
builder.setInsertionPoint(loadOp);
9191
auto prefetchOp = builder.create<ttgi::PrefetchOp>(
92-
loadOp->getLoc(), loadOp.getPtr(), loadOp.getCache(), loadOp.getEvict(),
93-
loadOp.getIsVolatile());
92+
loadOp->getLoc(), loadOp.getPtr(), loadOp.getMask(), loadOp.getCache(),
93+
loadOp.getEvict(), loadOp.getIsVolatile());
9494

9595
// inherit attributes from the load operation
9696
auto attrs = loadOp->getAttrDictionary();

0 commit comments

Comments (0)