diff --git a/test/TritonIntelGPU/loop-pipeline.mlir b/test/TritonIntelGPU/loop-pipeline.mlir index 4635fedea1..e16e1405a1 100644 --- a/test/TritonIntelGPU/loop-pipeline.mlir +++ b/test/TritonIntelGPU/loop-pipeline.mlir @@ -79,15 +79,19 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, %51 = arith.muli %arg7, %c32_i32 : i32 %52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1> // COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop. + // CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} + // CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch {{.*}}, %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> - // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> + // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>) : i32 { + // CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} // CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]> // CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]> - // CHECK: triton_intel_gpu.prefetch %[[VAL_106]] {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> - // CHECK: triton_intel_gpu.prefetch %[[VAL_107]] {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> + // CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch %[[VAL_106]], %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: %[[VAL_121:.*]] = ttg.convert_layout %[[VAL_116]] : tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> @@ -166,12 +170,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>) // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // CHECK-NEXT: scf.yield %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { @@ -239,12 +243,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>) // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // CHECK-NEXT: scf.yield %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { @@ -302,18 +306,18 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup %12 = arith.extsi %arg3 : i32 to i64 // CHECK: scf.for %[[OUTER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (i32) // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR1]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR1]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR2]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR2]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR3]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR3]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR4:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR4]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR4]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK-NEXT: scf.for %[[INNER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x128xf32, #blocked>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>) // CHECK: [[PTR5:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR5]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR5]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR6:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR6]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR6]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> %13 = scf.for %arg6 = %0 to %7 step %c448_i32 iter_args(%arg7 = %10) -> (i32) : i32 { %14 = arith.divsi %arg6, %11 : i32 %15 = arith.muli %14, %c8_i32 : i32 diff --git a/test/TritonIntelGPU/split-barrier.mlir b/test/TritonIntelGPU/split-barrier.mlir index db559c3e8b..a2db6e5c93 100644 --- a/test/TritonIntelGPU/split-barrier.mlir +++ b/test/TritonIntelGPU/split-barrier.mlir @@ -26,7 +26,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 // WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait @@ -73,7 +73,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 // WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp index 1788ba753b..d65be0221c 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp @@ -6,6 +6,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -149,12 +150,13 @@ static void collectOpsToPipeline(scf::ForOp forOp, } } -/// Combine the current mask with the given predicate. -static Value getPredMask(RewriterBase &rewriter, Type typeLike, - Value currentMask, Value pred) { +/// Return a new mask of type of shape \p typeLike, and value combining the +/// current mask \p currentMask with the given predicate \p pred. +static Value computeNewMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { Location loc = pred.getLoc(); Value mask = pred; - Type maskType = tt::getI1SameShape(typeLike); + Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike)); if (isa(maskType)) mask = rewriter.create(loc, maskType, pred); @@ -167,18 +169,17 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike, static Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred) { OpBuilder::InsertionGuard guard(rewriter); - if (mlir::isMemoryEffectFree(op) || isa(op)) + if (mlir::isMemoryEffectFree(op)) return op; - if (auto loadOp = dyn_cast(op)) { - rewriter.setInsertionPoint(loadOp); - Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), - loadOp.getMask(), pred); - loadOp.getMaskMutable().assign(mask); - return loadOp; - } - - llvm_unreachable("don't know how to predicate this operation"); + return TypeSwitch(op) + .Case([&](auto op) { + rewriter.setInsertionPoint(op); + Value mask = + computeNewMask(rewriter, op.getPtr().getType(), op.getMask(), pred); + op.getMaskMutable().assign(mask); + return op; + }); } /// Helper to get the defining operation of a value.