From 4384ad19d9a4f3c37a93243691aae8cd19e4e0c0 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Fri, 9 May 2025 17:55:18 +0000
Subject: [PATCH 01/16] Ensure block ptr is created with the same layout as
 the descriptor_load/descriptor_store operation that uses it

Signed-off-by: Tiotto, Ettore
---
 .../Intel/TensorDescToBlockPointer/basic.mlir | 29 ++++++++----
 .../Transforms/TensorDescToBlockPointer.cpp   | 46 +++++++++++++++----
 2 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir b/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir
index f039040e0d..022c714abf 100644
--- a/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir
+++ b/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir
@@ -1,15 +1,19 @@
 // RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s

-module {
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+
+module attributes {"ttg.num-warps" = 4 : i32} {
   tt.func public @test_load(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32) {
     %c1_i64 = arith.constant 1 : i64
     %c64_i32 = arith.constant 64 : i32
     %c8_i32 = arith.constant 8 : i32
     %0 = arith.extsi %arg2 : i32 to i64
-    %desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
-    %load = tt.descriptor_load %desc[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32>
+    %desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
+    %load1 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32>
+    %load2 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32, #blocked>
     tt.return
   }
+  // CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
   // CHECK: tt.func public @test_load([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
   // CHECK-NOT: tt.make_tensor_descriptor
   // CHECK-NOT: tt.descriptor_load
@@ -18,8 +22,10 @@ module {
   // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
   // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
   // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
-  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
-  // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<16x128xf32>>
+  // CHECK: [[TENSOR_PTR1:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
+  // CHECK: [[LOAD1:%.+]] = tt.load [[TENSOR_PTR1]] : !tt.ptr<tensor<16x128xf32>>
+  // CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32, #[[$BLOCKED]]>>
+  // CHECK: [[LOAD2:%.+]] = tt.load [[TENSOR_PTR2]] : !tt.ptr<tensor<16x128xf32, #[[$BLOCKED]]>>
   // CHECK: tt.return
   // CHECK: }

@@ -28,9 +34,11 @@ module {
     %c64_i32 = arith.constant 64 : i32
     %c8_i32 = arith.constant 8 : i32
     %cst = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
+    %cst1 = arith.constant dense<1.000000e+00> : tensor<16x128xf32, #blocked>
     %0 = arith.extsi %arg2 : i32 to i64
-    %desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
-    tt.descriptor_store %desc[%c8_i32, %c64_i32], %cst : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32>
+    %desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
+    tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32>
+    tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst1 : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32, #blocked>
     tt.return
   }
   // CHECK: tt.func public @test_store([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
@@ -40,10 +48,13 @@ module {
   // CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
   // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
   // CHECK-DAG: [[CST:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
+  // CHECK-DAG: [[CST1:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32, #[[$BLOCKED]]>
   // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
   // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
-  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
-  // CHECK: tt.store [[TENSOR_PTR]], [[CST]] : !tt.ptr<tensor<16x128xf32>>
+  // CHECK: [[TENSOR_PTR1:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
+  // CHECK: tt.store [[TENSOR_PTR1]], [[CST]] : !tt.ptr<tensor<16x128xf32>>
+  // CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32, #[[$BLOCKED]]>>
+  // CHECK: tt.store [[TENSOR_PTR2]], [[CST1]] : !tt.ptr<tensor<16x128xf32, #[[$BLOCKED]]>>
   // CHECK: tt.return
   // CHECK: }
 }

diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp
index 5df0a0245a..184b69e596 100644
--- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp
+++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp
@@ -1,4 +1,5 @@
 #include "intel/include/Dialect/Triton/Transforms/Passes.h"
+#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Verifier.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -6,7 +7,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
+// #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"

@@ -73,6 +74,8 @@ struct TritonIntelTensorDescToBlockPointer
         });

     finalize();
+
+    llvm::errs() << "moduleOp: " << *moduleOp << "\n";
     assert(succeeded(verify(moduleOp)) && "Module verification failed");
   }

@@ -121,11 +124,19 @@ struct TritonIntelTensorDescToBlockPointer
     return (yieldedVal != blockArg);
   }

+  // Create a new block pointer if a suitable one doesn't already exist.
+  // Otherwise, return the existing one. The function takes the base, shape,
+  // strides, offsets, sizes of the block pointer to create/lookup and its
+  // tensor element type (to ensure the block pointer has the tensor layout).
   Value findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape,
                                   ValueRange strides, ValueRange offsets,
-                                  ArrayRef<int32_t> sizes, OpBuilder &builder) {
+                                  ArrayRef<int32_t> sizes,
+                                  RankedTensorType tensorType,
+                                  OpBuilder &builder) {
     Block *block = builder.getInsertionBlock();
     const Block::iterator insertPoint = builder.getInsertionPoint();
+    auto ptrType = tt::PointerType::get(
+        tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup);

     auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) {
       if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
@@ -138,7 +149,9 @@
           }
           return true;
         };
-        return makeTensorPtrOp.getBase() == base &&
+
+        return makeTensorPtrOp.getType() == ptrType &&
+               makeTensorPtrOp.getBase() == base &&
                makeTensorPtrOp.getShape() == shape &&
                makeTensorPtrOp.getStrides() == strides &&
                makeTensorPtrOp.getOffsets() == offsets &&
@@ -147,10 +160,16 @@
       return false;
     });

+    auto makeTensorPtrOp = [&]() {
+      Value makeTensorPtr = builder.create<tt::MakeTensorPtrOp>(
+          loc, base, shape, strides, offsets, sizes,
+          builder.getDenseI32ArrayAttr({1, 0}));
+      makeTensorPtr.setType(ptrType);
+      return makeTensorPtr;
+    };
+
     return (it != insertPoint) ? cast<tt::MakeTensorPtrOp>(*it)
-                               : builder.createOrFold<tt::MakeTensorPtrOp>(
-                                     loc, base, shape, strides, offsets, sizes,
-                                     builder.getDenseI32ArrayAttr({1, 0}));
+                               : makeTensorPtrOp();
   }

   template <typename OpTy,
             std::enable_if_t<llvm::is_one_of<OpTy, tt::DescriptorLoadOp,
                                              tt::DescriptorStoreOp>::value,
                              bool> = true>
   LogicalResult rewriteDescriptorLoadOrStoreOp(OpTy op) {
     SmallVector<Value> shapes, strides, offsets;
     SmallVector<int32_t> sizes;
@@ -193,16 +217,22 @@
       sizes.push_back(static_cast<int32_t>(size));
     }

+    constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
+    RankedTensorType tensorType;
+    if constexpr (isLoad)
+      tensorType = op.getResult().getType();
+    else
+      tensorType = op.getSrc().getType();
+
     Value makeTensorPtrOp =
         findOrCreateMakeTensorPtr(loc, makeTensorDescOp.getBase(), shapes,
-                                  strides, offsets, sizes, builder);
+                                  strides, offsets, sizes, tensorType, builder);

     LLVM_DEBUG({
       llvm::dbgs() << "With:\n";
       llvm::dbgs().indent(2) << makeTensorPtrOp << "\n";
     });

-    constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
     if constexpr (isLoad) {
       auto loadOp = builder.createOrFold<tt::LoadOp>(
           loc, makeTensorPtrOp, op.getCache(), op.getEvict(),

From f0ce91c5ca3f209ce9db5dd7992d6800b09cb443 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Fri, 9 May 2025 18:15:30 +0000
Subject: [PATCH 02/16] Remove naked print and unnecessary headers

Signed-off-by: Tiotto, Ettore
---
 .../Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp
index 184b69e596..09d673f133 100644
--- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp
+++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp
@@ -7,8 +7,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
-// #include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"

 #define DEBUG_TYPE "triton-intel-tdesc-to-block-pointer"
@@ -74,8 +72,6 @@ struct TritonIntelTensorDescToBlockPointer
         });

     finalize();
-
-    llvm::errs() << "moduleOp: " << *moduleOp << "\n";
     assert(succeeded(verify(moduleOp)) && "Module verification failed");
   }

From 3543e6e55c12853c655551657723db523f59ccd2 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Wed, 14 May 2025 20:49:06 +0000
Subject: [PATCH 03/16] WIP: TensorDescToBlockPtr updates
Signed-off-by: Tiotto, Ettore --- .../Transforms/TensorDescToBlockPointer.cpp | 153 +++++++++++++++++- .../RemoveLayoutConversions.cpp | 63 ++++++-- 2 files changed, 198 insertions(+), 18 deletions(-) diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 09d673f133..495f15b088 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -1,5 +1,6 @@ #include "intel/include/Dialect/Triton/Transforms/Passes.h" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" +#include "intel/include/Utils/Utility.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Verifier.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -61,6 +62,14 @@ struct TritonIntelTensorDescToBlockPointer moduleOp->walk([&](Operation *op) { return TypeSwitch(op) +#if 1 + .Case([&](auto makeTensorDescOp) { + if (failed(rewriteMakeTensorDescriptorOp(makeTensorDescOp))) + makeTensorDescOp->emitRemark( + "TritonIntelTensorDescToBlockPointer: Failed to rewrite"); + return WalkResult::advance(); + }) +#endif .Case( [&](auto loadOrStoreOp) { if (failed(rewriteDescriptorLoadOrStoreOp(loadOrStoreOp))) @@ -88,7 +97,7 @@ struct TritonIntelTensorDescToBlockPointer unsigned numIVs = forOp.getNumInductionVars(); int initArgIdx = blockArg.getArgNumber() - numIVs; if (isModifiedInLoop(forOp, blockArg)) { - LLVM_DEBUG(llvm::dbgs() << blockArg << "is loop variant"); + LLVM_DEBUG(llvm::dbgs() << blockArg << " is loop variant\n"); return nullptr; } Operation::operand_range initArgs = forOp.getInitArgs(); @@ -168,11 +177,110 @@ struct TritonIntelTensorDescToBlockPointer : makeTensorPtrOp(); } + // Create a new block pointer if a suitable one doesn't already exist. + // Otherwise, return the existing one. The function takes the base, shape, + // strides, offsets, sizes of the block pointer to create/lookup and its + // tensor element type (to ensure the block pointer has the tensor layout). + Value findOrCreateMakeTensorPtrTmp(Location loc, Value base, ValueRange shape, + ValueRange strides, ValueRange offsets, + ArrayRef sizes, + OpBuilder &builder) { + Block *block = builder.getInsertionBlock(); + const Block::iterator insertPoint = builder.getInsertionPoint(); + auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) { + if (auto makeTensorPtrOp = dyn_cast(op)) { + triton::PointerType resType = makeTensorPtrOp.getResult().getType(); + auto tensorType = cast(resType.getPointeeType()); + auto sameShape = [](ArrayRef arr1, ArrayRef arr2) { + for (auto [dim1, dim2] : llvm::zip(arr1, arr2)) { + if (dim1 != dim2) + return false; + } + return true; + }; + + return makeTensorPtrOp.getBase() == base && + makeTensorPtrOp.getShape() == shape && + makeTensorPtrOp.getStrides() == strides && + makeTensorPtrOp.getOffsets() == offsets && + sameShape(tensorType.getShape(), sizes); + } + return false; + }); + + auto makeTensorPtrOp = [&]() { + Value makeTensorPtr = builder.create( + loc, base, shape, strides, offsets, sizes, + builder.getDenseI32ArrayAttr({1, 0})); + return makeTensorPtr; + }; + + return (it != insertPoint) ? 
cast(*it) + : makeTensorPtrOp(); + } + + LogicalResult rewriteMakeTensorDescriptorOp(tt::MakeTensorDescOp op) { + assert(op && "Expecting a valid operation"); + LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n"); + + OpBuilder builder(op); + Location loc = op.getLoc(); + tt::TensorDescType tDescType = op.getType(); + + // Create a new block pointer if a suitable one doesn't already exist. + SmallVector shapes, strides, offsets; + SmallVector sizes; + for (const auto [shape, stride, size] : + llvm::zip(op.getShape(), op.getStrides(), + tDescType.getBlockType().getShape())) { + shapes.push_back(findOrCreateCast( + loc, shape, builder.getIntegerType(shapeAndStridesBitwidth), + builder)); + strides.push_back(findOrCreateCast( + loc, stride, builder.getIntegerType(shapeAndStridesBitwidth), + builder)); + Value zero = + tt::intel::findOrCreateIntConstant(loc, 0, offsetBitwidth, builder); + offsets.push_back(zero); + sizes.push_back(static_cast(size)); + } + + Value tensorPtr = findOrCreateMakeTensorPtrTmp( + loc, op.getBase(), shapes, strides, offsets, sizes, builder); + LLVM_DEBUG({ + llvm::dbgs() << "With:\n"; + llvm::dbgs().indent(2) << tensorPtr << "\n"; + }); + + op.replaceAllUsesWith(tensorPtr); + cleanUp.insert(op); + + for (Operation *user : tensorPtr.getUsers()) { + if (auto forOp = dyn_cast(user)) { + for (auto it : + llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs(), + forOp.getResults(), forOp.getYieldedValues())) { + Value initArg = std::get<0>(it), rgnInitArg = std::get<1>(it), + loopRes = std::get<2>(it), yieldVal = std::get<3>(it); + assert(rgnInitArg.getType() == loopRes.getType() && + rgnInitArg.getType() == yieldVal.getType() && "Type mismatch"); + if (rgnInitArg.getType() != initArg.getType()) { + rgnInitArg.setType(initArg.getType()); + loopRes.setType(initArg.getType()); + yieldVal.setType(initArg.getType()); + } + } + } + } + + return success(); + } + template ::value, bool> = true> - LogicalResult rewriteDescriptorLoadOrStoreOp(OpTy op) { + LogicalResult rewriteDescriptorLoadOrStoreOpOld(OpTy op) { assert(op && "Expecting a valid operation"); LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n"); @@ -191,11 +299,6 @@ struct TritonIntelTensorDescToBlockPointer LLVM_DEBUG(llvm::dbgs() << "which has tdesc: " << makeTensorDescOp << "\n"); - auto createPointerType = [](RankedTensorType tensorType) { - return tt::PointerType::get( - tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup); - }; - // Create a new block pointer if a suitable one doesn't already exist. 
SmallVector shapes, strides, offsets; SmallVector sizes; @@ -248,6 +351,42 @@ struct TritonIntelTensorDescToBlockPointer return success(); } + template ::value, + bool> = true> + LogicalResult rewriteDescriptorLoadOrStoreOp(OpTy op) { + assert(op && "Expecting a valid operation"); + LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n"); + + OpBuilder builder(op); + Location loc = op.getLoc(); + Value ptr = op.getOperand(0); + assert(triton::isTensorPointerType(ptr.getType()) && + "Expecting a block ptr"); + + ptr = + builder.create(loc, ptr.getType(), ptr, op.getIndices()); + + constexpr bool isLoad = std::is_same_v; + if constexpr (isLoad) { + auto loadOp = builder.createOrFold(loc, ptr, op.getCache(), + op.getEvict(), + /*volatile*/ false); + LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n"); + op.replaceAllUsesWith(loadOp); + } else { + [[maybe_unused]] auto storeOp = builder.createOrFold( + loc, ptr, op.getSrc(), tt::CacheModifier::NONE, + tt::EvictionPolicy::NORMAL); + LLVM_DEBUG(llvm::dbgs().indent(2) << storeOp << "\n"); + } + + cleanUp.insert(op); + + return success(); + } + void finalize() { // Cleanup unused operations. bool erasedOperation; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 4a5f93e08c..04b022052d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -692,6 +692,29 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { assertOp->setOperand(0, newOperand); } +// Recursively update the operands in a chain of AdvanceOps, after setting the +// pointer operand of the first one. +static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp, + Value data, OpBuilder &rewriter) { + auto newAdvanceOp = + rewriter.create(advanceOp.getLoc(), makeTensorPtrOp.getType(), + makeTensorPtrOp, advanceOp.getOffsets()); + + SmallVector advanceOpUsers(advanceOp->getUsers()); + for (Operation *user : advanceOpUsers) { + if (auto storeOp = dyn_cast(user)) { + // Update the StoreOp operands. + storeOp.setOperand(0, newAdvanceOp); + storeOp.setOperand(1, data); + } else if (auto nextAdvanceOp = dyn_cast(user)) { + // Recursive call to handle the next AdvanceOp in the chain. + updateAdvanceOpChain(nextAdvanceOp, makeTensorPtrOp, data, rewriter); + } else { + llvm_unreachable("Unexpected user of AdvanceOp"); + } + } +} + bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { // Disable 2D block store on LTS. if (!storeOp->getParentOfType()->hasAttr( @@ -705,13 +728,16 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { if (!isTensorPointerType(ptr.getType())) return false; - // 2D block store are preceeded by a MakeTensorPtrOp - auto makeTensorPtrOp = ptr.getDefiningOp(); - if (!makeTensorPtrOp) - return false; + // Locate the operation that created the block pointer. + Operation *defOp = ptr.getDefiningOp(); + while (auto advanceOp = dyn_cast(defOp)) + defOp = advanceOp.getPtr().getDefiningOp(); + assert(isa(defOp) && + "MakeTensorPtrOp should be the only op that creates a tensor pointer"); + auto makeTensorPtrOp = cast(defOp); - // DPAS encoding have to be propagate if conversion from DPAS to - // other has been done before. + // DPAS encoding have to be propagated if conversion from a DPAS layout to + // another layout has been done before. 
   auto convertOp = storeOp.getValue().getDefiningOp<ConvertLayoutOp>();
   PointerType newPtrType;
   Attribute encoding;
@@ -758,21 +784,36 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
     encoding = convertOpSrcType.getEncoding();
   }

-  // We create a new MakeTensorPtrOp with the new data type.
+  // Create a new MakeTensorPtrOp with the new layout.
   OpBuilder rewriter(makeTensorPtrOp);
-  Value newStorePtr = rewriter.create<MakeTensorPtrOp>(
+  Value newMakeTensorPtrOp = rewriter.create<MakeTensorPtrOp>(
       makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(),
       makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(),
-      makeTensorPtrOp.getOffsets(), rewriter.getDenseI32ArrayAttr({1, 0}));
-
+      makeTensorPtrOp.getOffsets(), makeTensorPtrOp.getOrderAttr());
+
+#if 1
+  // Update the store operation with the new layout.
+  for (Operation *user : makeTensorPtrOp->getUsers()) {
+    if (auto storeOp = dyn_cast<StoreOp>(user)) {
+      storeOp.setOperand(0, newMakeTensorPtrOp);
+      storeOp.setOperand(1, getValueAs(value, encoding));
+    } else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
+      updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp,
+                           getValueAs(value, encoding), rewriter);
+    } else {
+      llvm_unreachable("Unexpected user of MakeTensorPtrOp");
+    }
+  }
+#else
   // The encoding of the StoreOp is updated with the new
   // operands:
   // - the Ptr created by the MakeTensorPtrOp with the new data
   // type
   // - the forwarded DPAS encoding.
   Value newOperand = getValueAs(value, encoding);
-  storeOp.setOperand(0, newStorePtr);
+  storeOp.setOperand(0, newMakeTensorPtrOp);
   storeOp.setOperand(1, newOperand);
+#endif

   // If the DPAS encoding is forwarded, we do not need the
   // convertOp anymore if the convertOp was only used by the

From 38ef6c353f14a8abe57ea4436a7d38cb648a0d69 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 15 May 2025 18:22:28 +0000
Subject: [PATCH 04/16] WIP: RemoveLayoutConversions improvement for
 tt.advance operation

Signed-off-by: Tiotto, Ettore
---
 .../RemoveLayoutConversions.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index 04b022052d..6372a95c0f 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -695,7 +695,8 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) {
 // Recursively update the operands in a chain of AdvanceOps, after setting the
 // pointer operand of the first one.
 static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp,
-                                 Value data, OpBuilder &rewriter) {
+                                 Value data) {
+  OpBuilder rewriter(advanceOp);
   auto newAdvanceOp =
       rewriter.create<AdvanceOp>(advanceOp.getLoc(), makeTensorPtrOp.getType(),
                                  makeTensorPtrOp, advanceOp.getOffsets());

   SmallVector<Operation *> advanceOpUsers(advanceOp->getUsers());
   for (Operation *user : advanceOpUsers) {
     if (auto storeOp = dyn_cast<StoreOp>(user)) {
       // Update the StoreOp operands.
       storeOp.setOperand(0, newAdvanceOp);
       storeOp.setOperand(1, data);
     } else if (auto nextAdvanceOp = dyn_cast<AdvanceOp>(user)) {
       // Recursive call to handle the next AdvanceOp in the chain.
- updateAdvanceOpChain(nextAdvanceOp, makeTensorPtrOp, data, rewriter); + updateAdvanceOpChain(nextAdvanceOp, makeTensorPtrOp, data); } else { llvm_unreachable("Unexpected user of AdvanceOp"); } @@ -799,7 +800,7 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { storeOp.setOperand(1, getValueAs(value, encoding)); } else if (auto advanceOp = dyn_cast(user)) { updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, - getValueAs(value, encoding), rewriter); + getValueAs(value, encoding)); } else { llvm_unreachable("Unexpected user of MakeTensorPtrOp"); } @@ -1648,6 +1649,7 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after propagating layouts forward:\n"; m.dump(); + assert(succeeded(verify(m)) && "Module verification failed"); }); cleanupConvertOps(); @@ -1658,6 +1660,7 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after backward remat:\n"; m.dump(); + assert(succeeded(verify(m)) && "Module verification failed"); }); // Cleanup dummy converts created during backward remat. @@ -1669,6 +1672,7 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after hoisting converts:\n"; m.dump(); + assert(succeeded(verify(m)) && "Module verification failed"); }); // 4. Apply clean up patterns to remove remove dead convert and dead code @@ -1684,6 +1688,7 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after final cleanups:\n"; m.dump(); + assert(succeeded(verify(m)) && "Module verification failed"); }); } }; From 9eb16ec1c1f6f97e4a4dd499fac7e84612d39790 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 15 May 2025 20:55:47 +0000 Subject: [PATCH 05/16] WIP: TensorDescToBlockPtr updates Signed-off-by: Tiotto, Ettore --- .../Intel/TensorDescToBlockPointer/basic.mlir | 19 ++-- .../Intel/TensorDescToBlockPointer/loop.mlir | 90 +++++++++++-------- .../backward_combine_dpas_dot_layout.mlir | 5 +- .../RemoveLayoutConversions.cpp | 13 +-- 4 files changed, 69 insertions(+), 58 deletions(-) diff --git a/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir b/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir index 022c714abf..c6b8717eb2 100644 --- a/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir +++ b/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir @@ -1,7 +1,5 @@ // RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> - module attributes {"ttg.num-warps" = 4 : i32} { tt.func public @test_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { %c1_i64 = arith.constant 1 : i64 @@ -10,22 +8,20 @@ module attributes {"ttg.num-warps" = 4 : i32} { %0 = arith.extsi %arg2 : i32 to i64 %desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > %load1 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc> -> tensor<16x128xf32> - %load2 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc> -> tensor<16x128xf32, #blocked> tt.return } - // CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: tt.func public @test_load([[PARAM_0:%.+]]: !tt.ptr, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) { // CHECK-NOT: tt.make_tensor_descriptor // CHECK-NOT: tt.descriptor_load + // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 // CHECK-DAG: [[CST_64_i32:%.+]] = 
arith.constant 64 : i32 // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32 // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64 // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64 - // CHECK: [[TENSOR_PTR1:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : > + // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > + // CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[TENSOR_PTR]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] : > // CHECK: [[LOAD1:%.+]] = tt.load [[TENSOR_PTR1]] : !tt.ptr> - // CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : > - // CHECK: [[LOAD2:%.+]] = tt.load [[TENSOR_PTR2]] : !tt.ptr> // CHECK: tt.return // CHECK: } @@ -34,27 +30,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %cst = arith.constant dense<1.000000e+00> : tensor<16x128xf32> - %cst1 = arith.constant dense<1.000000e+00> : tensor<16x128xf32, #blocked> %0 = arith.extsi %arg2 : i32 to i64 %desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst : !tt.tensordesc>, tensor<16x128xf32> - tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst1 : !tt.tensordesc>, tensor<16x128xf32, #blocked> tt.return } // CHECK: tt.func public @test_store([[PARAM_0:%.+]]: !tt.ptr, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) { // CHECK-NOT: tt.make_tensor_descriptor // CHECK-NOT: tt.descriptor_store + // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 // CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32 // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32 // CHECK-DAG: [[CST:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32> - // CHECK-DAG: [[CST1:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32, #[[$BLOCKED]]> // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64 // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64 - // CHECK: [[TENSOR_PTR1:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : > + // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > + // CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[TENSOR_PTR]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] : > // CHECK: tt.store [[TENSOR_PTR1]], [[CST]] : !tt.ptr> - // CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : > - // CHECK: tt.store [[TENSOR_PTR2]], [[CST1]] : !tt.ptr> // CHECK: tt.return // CHECK: } } diff --git a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir index 608289608a..1748a6d6c4 100644 --- a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir +++ b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir 
@@ -25,15 +25,16 @@ module { // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32 // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.000000e+00> : tensor<16x32xf16> - // CHECK-DAG: [[EXTSI_PARAM_2a:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64 - // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = {{.*}}, [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.tensordesc>, tensor<16x32xf16>) { - // CHECK-DAG: [[IDX_CAST_1:%.+]] = arith.index_cast [[IV]] : index to i32 - // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64 - // CHECK-DAG: [[EXTSI_PARAM_2b:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64 - // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2b]]], {{\[}}[[EXTSI_PARAM_2a]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] {{.*}} : > - // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr> + // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64 + // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64 + // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 + // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > + // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = [[TENSOR_PTR]], [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.ptr>, tensor<16x32xf16>) { + // CHECK: [[IDX_CAST:%.+]] = arith.index_cast [[IV]] : index to i32 + // CHECK: [[TENSOR_PTR_1:%.+]] = tt.advance [[VAR_arg1]], {{\[}}[[CST_8_i32]], [[IDX_CAST]]] : > + // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR_1]] : !tt.ptr> // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], [[LOAD]] : tensor<16x32xf16> - // CHECK: scf.yield {{.*}}, [[ADD]] : !tt.tensordesc>, tensor<16x32xf16> + // CHECK: scf.yield [[VAR_arg1]], [[ADD]] : !tt.ptr>, tensor<16x32xf16> // CHECK: } // CHECK: tt.return // CHECK: } @@ -60,12 +61,27 @@ module { tt.return } // CHECK: tt.func public @load_in_loop2({{.*}}) { - // CHECK-NOT: tt.make_tensor_ptr - // CHECK-NOT: tt.load - // CHECK: tt.make_tensor_descriptor - // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} -> (!tt.tensordesc>, tensor<16x32xf16>) { - // CHECK: tt.descriptor_load - // CHECK: tt.make_tensor_descriptor + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK-NOT: tt.descriptor_load + // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 + // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32 + // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.000000e+00> : tensor<16x32xf16> + // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64 + // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64 + // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 + // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > + // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = [[TENSOR_PTR]], [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.ptr>, tensor<16x32xf16>) { + // CHECK: [[IDX_CAST:%.+]] = arith.index_cast [[IV]] : index to i32 + // CHECK: [[TENSOR_PTR_1:%.+]] = tt.advance [[VAR_arg1]], {{\[}}[[CST_8_i32]], [[IDX_CAST]]] : > + // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR_1]] : !tt.ptr> + // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], 
[[LOAD]] : tensor<16x32xf16>
+  // CHECK-DAG: [[EXTSI_PARAM_1a:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
+  // CHECK-DAG: [[EXTSI_PARAM_2a:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
+  // CHECK-DAG: [[CST_0_i32_1:%.+]] = arith.constant 0 : i32
+  // CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_2a]], [[EXTSI_PARAM_1a]]], {{\[}}[[CST_1_i64]], [[EXTSI_PARAM_2]]], {{\[}}[[CST_0_i32_1]], [[CST_0_i32_1]]] {{.*}} : <tensor<16x32xf16>>
+  // CHECK: [[CMP:%.+]] = arith.cmpi eq, [[IDX_CAST]], [[CST_8_i32]] : i32
+  // CHECK: [[TENSOR_PTR3:%.+]] = arith.select [[CMP]], [[VAR_arg1]], [[TENSOR_PTR2]] : !tt.ptr<tensor<16x32xf16>>
+  // CHECK: scf.yield [[TENSOR_PTR3]], [[ADD]] : !tt.ptr<tensor<16x32xf16>>, tensor<16x32xf16>
   // CHECK: }
   // CHECK: tt.return
   // CHECK: }
@@ -87,10 +103,12 @@ module {
     tt.return
   }
   // CHECK: tt.func public @load_uses_loop_result({{.*}}) {
-  // CHECK-NOT: tt.make_tensor_ptr
   // CHECK-NOT: tt.load
-  // CHECK: tt.make_tensor_descriptor
-  // CHECK: tt.descriptor_load
+  // CHECK-NOT: tt.make_tensor_descriptor
+  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr {{.*}} : <tensor<16x32xf16>>
+  // CHECK: [[FOR_RES:%.+]] = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = [[TENSOR_PTR]]) -> (!tt.ptr<tensor<16x32xf16>>)
+  // CHECK: [[TENSOR_PTR1:%.+]] = tt.advance [[FOR_RES]], {{.*}} : <tensor<16x32xf16>>
+  // CHECK: tt.load [[TENSOR_PTR1]] : !tt.ptr<tensor<16x32xf16>>
   // CHECK: tt.return
   // CHECK: }

@@ -115,18 +133,19 @@ module {
     tt.return
   }
   // CHECK: tt.func public @store_in_loop1([[PARAM_0:%.+]]: !tt.ptr<f16>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
   // CHECK-NOT: tt.make_tensor_descriptor
   // CHECK-NOT: tt.descriptor_store
+  // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
   // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
   // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
   // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
-  // CHECK-DAG: [[EXTSI_PARAM_2a:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
-  // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = {{.*}}, [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
-  // CHECK-DAG: [[IDX_CAST_1:%.+]] = arith.index_cast [[IV]] : index to i32
-  // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
-  // CHECK-DAG: [[EXTSI_PARAM_2b:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
-  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2b]]], {{\[}}[[EXTSI_PARAM_2a]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] {{.*}} : <tensor<16x32xf16>>
-  // CHECK: tt.store [[TENSOR_PTR]], [[VAR_arg2]] : !tt.ptr<tensor<16x32xf16>>
+  // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
+  // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
+  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x32xf16>>
+  // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = [[TENSOR_PTR]], [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.ptr<tensor<16x32xf16>>, tensor<16x32xf16>) {
+  // CHECK: [[IDX_CAST_1:%.+]] = arith.index_cast [[IV]] : index to i32
+  // CHECK: [[TENSOR_PTR_1:%.+]] = tt.advance [[VAR_arg1]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] : <tensor<16x32xf16>>
+  // CHECK: tt.store [[TENSOR_PTR_1]], [[VAR_arg2]] : !tt.ptr<tensor<16x32xf16>>
   // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], [[CST]] : tensor<16x32xf16>
-  // CHECK: scf.yield {{.*}}, [[ADD]] : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
+  // CHECK: scf.yield [[VAR_arg1]], [[ADD]] : !tt.ptr<tensor<16x32xf16>>, tensor<16x32xf16>
   // CHECK: }
   // 
CHECK: tt.return // CHECK: } @@ -153,12 +172,12 @@ module { tt.return } // CHECK: tt.func public @store_in_loop2({{.*}}) { - // CHECK-NOT: tt.make_tensor_ptr - // CHECK-NOT: tt.store - // CHECK: tt.make_tensor_descriptor - // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} -> (!tt.tensordesc>, tensor<16x32xf16>) { - // CHECK: tt.descriptor_store - // CHECK: tt.make_tensor_descriptor + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK-NOT: tt.descriptor_store + // CHECK: tt.make_tensor_ptr + // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} -> (!tt.ptr>, tensor<16x32xf16>) { + // CHECK: tt.advance + // CHECK: tt.store // CHECK: } // CHECK: tt.return // CHECK: } @@ -181,10 +200,11 @@ module { tt.return } // CHECK: tt.func public @store_uses_loop_result({{.*}}) { - // CHECK-NOT: tt.make_tensor_ptr - // CHECK-NOT: tt.store - // CHECK: tt.make_tensor_descriptor - // CHECK: tt.descriptor_store + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK-NOT: tt.descriptor_store + // CHECK: tt.make_tensor_ptr + // CHECK: tt.advance + // CHECK: tt.store // CHECK: tt.return // CHECK: } diff --git a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir index 26338c9af3..427066b097 100644 --- a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir +++ b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir @@ -12,6 +12,7 @@ #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { + // CHECK: matmul_kernel_with_block_pointers tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i64, %arg7: i32, %arg8: i64) { %c8_i32 = arith.constant 8 : i32 %c64_i32 = arith.constant 64 : i32 @@ -66,7 +67,6 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, } %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas> %25 = ttg.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1> - %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%arg8, %c1_i64], [%14, %19] {order = array} : > tt.store %27, %25 {boundaryCheck = array} : !tt.ptr> tt.return @@ -86,6 +86,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { + // CHECK: matmul_kernel_with_block_pointers tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { %c8_i32 = arith.constant 8 : i32 %c64_i32 = arith.constant 64 : i32 @@ -154,6 +155,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { + // CHECK: matmul_kernel_with_block_pointers tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, 
%arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg13: !tt.ptr, %arg14: !tt.ptr) { %c8_i32 = arith.constant 8 : i32 %c64_i32 = arith.constant 64 : i32 @@ -233,6 +235,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { + // CHECK: matmul_kernel_with_block_pointers tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { %c1_i64 = arith.constant 1 : i64 %c0_i32 = arith.constant 0 : i32 diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 6372a95c0f..c6063b9d66 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -710,8 +710,6 @@ static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp, } else if (auto nextAdvanceOp = dyn_cast(user)) { // Recursive call to handle the next AdvanceOp in the chain. updateAdvanceOpChain(nextAdvanceOp, makeTensorPtrOp, data); - } else { - llvm_unreachable("Unexpected user of AdvanceOp"); } } } @@ -794,22 +792,19 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { #if 1 // Update the store operation with the new layout. - for (Operation *user : makeTensorPtrOp->getUsers()) { + SmallVector makeTensorPtrOpUsers(makeTensorPtrOp->getUsers()); + for (Operation *user : makeTensorPtrOpUsers) { if (auto storeOp = dyn_cast(user)) { storeOp.setOperand(0, newMakeTensorPtrOp); storeOp.setOperand(1, getValueAs(value, encoding)); } else if (auto advanceOp = dyn_cast(user)) { updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, getValueAs(value, encoding)); - } else { - llvm_unreachable("Unexpected user of MakeTensorPtrOp"); } } #else - // The encoding of the StoreOp is updated with the new - // operands: - // - the Ptr created by the MakeTensorPtrOp with the new data - // type + // The encoding of the StoreOp is updated with the new operands: + // - the Ptr created by the MakeTensorPtrOp with the new data type // - the forwarded DPAS encoding. 
Value newOperand = getValueAs(value, encoding); storeOp.setOperand(0, newMakeTensorPtrOp); From 7f9bbc9fff0531916fdd0136f6fbd1ed5e64f6db Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 15 May 2025 21:18:31 +0000 Subject: [PATCH 06/16] WIP: TensorDescToBlockPtr updates Signed-off-by: Tiotto, Ettore --- .../Transforms/TensorDescToBlockPointer.cpp | 174 +----------------- .../RemoveLayoutConversions.cpp | 25 +-- 2 files changed, 9 insertions(+), 190 deletions(-) diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 495f15b088..0544ccee42 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -62,14 +62,12 @@ struct TritonIntelTensorDescToBlockPointer moduleOp->walk([&](Operation *op) { return TypeSwitch(op) -#if 1 .Case([&](auto makeTensorDescOp) { if (failed(rewriteMakeTensorDescriptorOp(makeTensorDescOp))) makeTensorDescOp->emitRemark( "TritonIntelTensorDescToBlockPointer: Failed to rewrite"); return WalkResult::advance(); }) -#endif .Case( [&](auto loadOrStoreOp) { if (failed(rewriteDescriptorLoadOrStoreOp(loadOrStoreOp))) @@ -85,106 +83,13 @@ struct TritonIntelTensorDescToBlockPointer } private: - tt::MakeTensorDescOp getMakeTensorDescOp(Value base) const { - assert(base && isa(base.getType()) && - "Expecting tensor desc"); - - Operation *defOp = base.getDefiningOp(); - if (!defOp) { - BlockArgument blockArg = cast(base); - Operation *parentOp = blockArg.getOwner()->getParentOp(); - if (scf::ForOp forOp = dyn_cast(parentOp)) { - unsigned numIVs = forOp.getNumInductionVars(); - int initArgIdx = blockArg.getArgNumber() - numIVs; - if (isModifiedInLoop(forOp, blockArg)) { - LLVM_DEBUG(llvm::dbgs() << blockArg << " is loop variant\n"); - return nullptr; - } - Operation::operand_range initArgs = forOp.getInitArgs(); - assert(initArgIdx >= 0 && initArgIdx < initArgs.size() && - "Unexpected 'initArgIdx' value"); - return getMakeTensorDescOp(initArgs[initArgIdx]); - } - LLVM_DEBUG(llvm::dbgs() - << "TODO: Unhandled non operation: " << base << "\n"); - return nullptr; - } - - if (defOp->getNumRegions() != 0) { - LLVM_DEBUG(llvm::dbgs() << "TODO: defOp with region: " << *defOp << "\n"); - return nullptr; - } - if (auto makeTensorDescOp = dyn_cast(defOp)) - return makeTensorDescOp; - - llvm_unreachable("TODO: Unhandled defOp kind"); - return nullptr; - } - - bool isModifiedInLoop(scf::ForOp forOp, BlockArgument &blockArg) const { - unsigned argNo = blockArg.getArgNumber(); - unsigned numIVs = forOp.getNumInductionVars(); - int initArgIdx = blockArg.getArgNumber() - numIVs; - Value yieldedVal = forOp.getYieldedValues()[initArgIdx]; - return (yieldedVal != blockArg); - } - // Create a new block pointer if a suitable one doesn't already exist. // Otherwise, return the existing one. The function takes the base, shape, // strides, offsets, sizes of the block pointer to create/lookup and its // tensor element type (to ensure the block pointer has the tensor layout). 
Value findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape, ValueRange strides, ValueRange offsets, - ArrayRef sizes, - RankedTensorType tensorType, - OpBuilder &builder) { - Block *block = builder.getInsertionBlock(); - const Block::iterator insertPoint = builder.getInsertionPoint(); - auto ptrType = tt::PointerType::get( - tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup); - - auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) { - if (auto makeTensorPtrOp = dyn_cast(op)) { - triton::PointerType resType = makeTensorPtrOp.getResult().getType(); - auto tensorType = cast(resType.getPointeeType()); - auto sameShape = [](ArrayRef arr1, ArrayRef arr2) { - for (auto [dim1, dim2] : llvm::zip(arr1, arr2)) { - if (dim1 != dim2) - return false; - } - return true; - }; - - return makeTensorPtrOp.getType() == ptrType && - makeTensorPtrOp.getBase() == base && - makeTensorPtrOp.getShape() == shape && - makeTensorPtrOp.getStrides() == strides && - makeTensorPtrOp.getOffsets() == offsets && - sameShape(tensorType.getShape(), sizes); - } - return false; - }); - - auto makeTensorPtrOp = [&]() { - Value makeTensorPtr = builder.create( - loc, base, shape, strides, offsets, sizes, - builder.getDenseI32ArrayAttr({1, 0})); - makeTensorPtr.setType(ptrType); - return makeTensorPtr; - }; - - return (it != insertPoint) ? cast(*it) - : makeTensorPtrOp(); - } - - // Create a new block pointer if a suitable one doesn't already exist. - // Otherwise, return the existing one. The function takes the base, shape, - // strides, offsets, sizes of the block pointer to create/lookup and its - // tensor element type (to ensure the block pointer has the tensor layout). - Value findOrCreateMakeTensorPtrTmp(Location loc, Value base, ValueRange shape, - ValueRange strides, ValueRange offsets, - ArrayRef sizes, - OpBuilder &builder) { + ArrayRef sizes, OpBuilder &builder) { Block *block = builder.getInsertionBlock(); const Block::iterator insertPoint = builder.getInsertionPoint(); auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) { @@ -245,7 +150,7 @@ struct TritonIntelTensorDescToBlockPointer sizes.push_back(static_cast(size)); } - Value tensorPtr = findOrCreateMakeTensorPtrTmp( + Value tensorPtr = findOrCreateMakeTensorPtr( loc, op.getBase(), shapes, strides, offsets, sizes, builder); LLVM_DEBUG({ llvm::dbgs() << "With:\n"; @@ -276,81 +181,6 @@ struct TritonIntelTensorDescToBlockPointer return success(); } - template ::value, - bool> = true> - LogicalResult rewriteDescriptorLoadOrStoreOpOld(OpTy op) { - assert(op && "Expecting a valid operation"); - LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n"); - - OpBuilder builder(op); - Location loc = op.getLoc(); - TypedValue tDesc = op.getDesc(); - tt::TensorDescType tDescType = tDesc.getType(); - tt::MakeTensorDescOp makeTensorDescOp = getMakeTensorDescOp(tDesc); - - if (!makeTensorDescOp) { - LLVM_DEBUG(llvm::dbgs() - << "could not find tt.make_tensor_descriptor defining: " - << tDesc << "\n"); - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "which has tdesc: " << makeTensorDescOp << "\n"); - - // Create a new block pointer if a suitable one doesn't already exist. 
- SmallVector shapes, strides, offsets; - SmallVector sizes; - for (const auto [shape, stride, offset, size] : - llvm::zip(makeTensorDescOp.getShape(), makeTensorDescOp.getStrides(), - op.getIndices(), tDescType.getBlockType().getShape())) { - shapes.push_back(findOrCreateCast( - loc, shape, builder.getIntegerType(shapeAndStridesBitwidth), - builder)); - strides.push_back(findOrCreateCast( - loc, stride, builder.getIntegerType(shapeAndStridesBitwidth), - builder)); - offsets.push_back(findOrCreateCast( - loc, offset, builder.getIntegerType(offsetBitwidth), builder)); - sizes.push_back(static_cast(size)); - } - - constexpr bool isLoad = std::is_same_v; - RankedTensorType tensorType; - if constexpr (isLoad) - tensorType = op.getResult().getType(); - else - tensorType = op.getSrc().getType(); - - Value makeTensorPtrOp = - findOrCreateMakeTensorPtr(loc, makeTensorDescOp.getBase(), shapes, - strides, offsets, sizes, tensorType, builder); - - LLVM_DEBUG({ - llvm::dbgs() << "With:\n"; - llvm::dbgs().indent(2) << makeTensorPtrOp << "\n"; - }); - - if constexpr (isLoad) { - auto loadOp = builder.createOrFold( - loc, makeTensorPtrOp, op.getCache(), op.getEvict(), - /*volatile*/ false); - LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n"); - op.replaceAllUsesWith(loadOp); - } else { - [[maybe_unused]] auto storeOp = builder.createOrFold( - loc, makeTensorPtrOp, op.getSrc(), tt::CacheModifier::NONE, - tt::EvictionPolicy::NORMAL); - LLVM_DEBUG(llvm::dbgs().indent(2) << storeOp << "\n"); - } - - cleanUp.insert(op); - cleanUp.insert(makeTensorDescOp); - - return success(); - } - template ::value, diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index c6063b9d66..eab13c19eb 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -695,7 +695,7 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { // Recursively update the operands in a chain of AdvanceOps, after setting the // pointer operand of the first one. static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp, - Value data) { + Value dataToStore) { OpBuilder rewriter(advanceOp); auto newAdvanceOp = rewriter.create(advanceOp.getLoc(), makeTensorPtrOp.getType(), @@ -704,12 +704,10 @@ static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp, SmallVector advanceOpUsers(advanceOp->getUsers()); for (Operation *user : advanceOpUsers) { if (auto storeOp = dyn_cast(user)) { - // Update the StoreOp operands. storeOp.setOperand(0, newAdvanceOp); - storeOp.setOperand(1, data); - } else if (auto nextAdvanceOp = dyn_cast(user)) { - // Recursive call to handle the next AdvanceOp in the chain. - updateAdvanceOpChain(nextAdvanceOp, makeTensorPtrOp, data); + storeOp.setOperand(1, dataToStore); + } else if (auto advanceOp = dyn_cast(user)) { + updateAdvanceOpChain(advanceOp, makeTensorPtrOp, dataToStore); } } } @@ -790,26 +788,17 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(), makeTensorPtrOp.getOffsets(), makeTensorPtrOp.getOrderAttr()); -#if 1 // Update the store operation with the new layout. 
SmallVector makeTensorPtrOpUsers(makeTensorPtrOp->getUsers()); + auto dataToStore = getValueAs(value, encoding); for (Operation *user : makeTensorPtrOpUsers) { if (auto storeOp = dyn_cast(user)) { storeOp.setOperand(0, newMakeTensorPtrOp); - storeOp.setOperand(1, getValueAs(value, encoding)); + storeOp.setOperand(1, dataToStore); } else if (auto advanceOp = dyn_cast(user)) { - updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, - getValueAs(value, encoding)); + updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, dataToStore); } } -#else - // The encoding of the StoreOp is updated with the new operands: - // - the Ptr created by the MakeTensorPtrOp with the new data type - // - the forwarded DPAS encoding. - Value newOperand = getValueAs(value, encoding); - storeOp.setOperand(0, newMakeTensorPtrOp); - storeOp.setOperand(1, newOperand); -#endif // If the DPAS encoding is forwarded, we do not need the // convertOp anymore if the convertOp was only used by the From f6ed66a830ca1c9516e93a50d709985826970dd4 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 16 May 2025 15:49:23 +0000 Subject: [PATCH 07/16] WIP: TensorDescToBlockPtr updates Signed-off-by: Tiotto, Ettore --- .../RemoveLayoutConversions.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index eab13c19eb..24a09dceec 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -733,6 +733,9 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { "MakeTensorPtrOp should be the only op that creates a tensor pointer"); auto makeTensorPtrOp = cast(defOp); +// llvm::errs() << "storeOp: " << storeOp << "\n"; +// llvm::errs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n"; + // DPAS encoding have to be propagated if conversion from a DPAS layout to // another layout has been done before. auto convertOp = storeOp.getValue().getDefiningOp(); @@ -790,8 +793,14 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { // Update the store operation with the new layout. SmallVector makeTensorPtrOpUsers(makeTensorPtrOp->getUsers()); - auto dataToStore = getValueAs(value, encoding); + Value dataToStore = getValueAs(value, encoding); + Block *storeBB = storeOp->getBlock(); +// llvm::errs() << "dataToStore: " << dataToStore << "\n"; for (Operation *user : makeTensorPtrOpUsers) { + Block *userBB = user->getBlock(); + if (storeBB != userBB) + continue; + if (auto storeOp = dyn_cast(user)) { storeOp.setOperand(0, newMakeTensorPtrOp); storeOp.setOperand(1, dataToStore); @@ -800,6 +809,10 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { } } +// auto mod = makeTensorPtrOp->getParentOfType(); +// llvm::errs() << "Module after rewriting store:\n"; mod.dump(); llvm::errs() << "\n"; +// assert(succeeded(verify(mod)) && "Module verification failed"); + // If the DPAS encoding is forwarded, we do not need the // convertOp anymore if the convertOp was only used by the // storeOp. 
Same for the initial MakeTensorPtrOp, if it was From cb4bb2efca0ed9526daa9a56313d441fb492da30 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 20 May 2025 13:06:13 +0000 Subject: [PATCH 08/16] WIP: TensorDescToBlockPtr updates Signed-off-by: Tiotto, Ettore --- .../RemoveLayoutConversions.cpp | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 24a09dceec..c90d27de92 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -733,9 +733,6 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { "MakeTensorPtrOp should be the only op that creates a tensor pointer"); auto makeTensorPtrOp = cast(defOp); -// llvm::errs() << "storeOp: " << storeOp << "\n"; -// llvm::errs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n"; - // DPAS encoding have to be propagated if conversion from a DPAS layout to // another layout has been done before. auto convertOp = storeOp.getValue().getDefiningOp(); @@ -795,10 +792,9 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { SmallVector makeTensorPtrOpUsers(makeTensorPtrOp->getUsers()); Value dataToStore = getValueAs(value, encoding); Block *storeBB = storeOp->getBlock(); -// llvm::errs() << "dataToStore: " << dataToStore << "\n"; for (Operation *user : makeTensorPtrOpUsers) { Block *userBB = user->getBlock(); - if (storeBB != userBB) + if (storeBB != userBB) continue; if (auto storeOp = dyn_cast(user)) { @@ -809,10 +805,6 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) { } } -// auto mod = makeTensorPtrOp->getParentOfType(); -// llvm::errs() << "Module after rewriting store:\n"; mod.dump(); llvm::errs() << "\n"; -// assert(succeeded(verify(mod)) && "Module verification failed"); - // If the DPAS encoding is forwarded, we do not need the // convertOp anymore if the convertOp was only used by the // storeOp. 
Same for the initial MakeTensorPtrOp, if it was From e5f74b895d7dd5d82eb502473d61a5cd6caeab91 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 22 May 2025 15:41:41 +0000 Subject: [PATCH 09/16] Address code review comments Signed-off-by: Tiotto, Ettore --- test/Triton/Intel/TensorDescToBlockPointer/basic.mlir | 2 +- .../Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir b/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir index c6b8717eb2..f3b9e1f020 100644 --- a/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir +++ b/test/Triton/Intel/TensorDescToBlockPointer/basic.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s -module attributes {"ttg.num-warps" = 4 : i32} { +module { tt.func public @test_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { %c1_i64 = arith.constant 1 : i64 %c64_i32 = arith.constant 64 : i32 diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 0544ccee42..102c48b069 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -162,11 +162,9 @@ struct TritonIntelTensorDescToBlockPointer for (Operation *user : tensorPtr.getUsers()) { if (auto forOp = dyn_cast(user)) { - for (auto it : + for (auto [initArg, rgnInitArg, loopRes, yieldVal] : llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs(), forOp.getResults(), forOp.getYieldedValues())) { - Value initArg = std::get<0>(it), rgnInitArg = std::get<1>(it), - loopRes = std::get<2>(it), yieldVal = std::get<3>(it); assert(rgnInitArg.getType() == loopRes.getType() && rgnInitArg.getType() == yieldVal.getType() && "Type mismatch"); if (rgnInitArg.getType() != initArg.getType()) { From 7fcbe40cacbce04279c17a76f32062f38c621444 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 22 May 2025 16:43:00 +0000 Subject: [PATCH 10/16] Add unit test Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/combine.mlir | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index 131de33d98..1fa663e204 100644 --- a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -2472,3 +2472,49 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup tt.return } } + +// ----- + +// COM: Test that the DPAS layout is propagated to the store operation in the presence of an advance operation updating its base pointer. 
+// CHECK-NOT: #ttg.blocked<{.*}> +// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> +#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> +#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { + // CHECK-LABEL: matmul_kernel_with_block_pointers + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c0_i64 = arith.constant 0 : i64 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #blocked1> + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr>, !tt.ptr>) : i32 { + // CHECK-NOT: ttg.convert_layout + %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major" } : !tt.ptr> + %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + %36 = ttg.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas> + %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0> + %31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1> + %32 = tt.dot %30, %31, %36, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas> + %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : > + %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : > + %35 = ttg.convert_layout %32 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1> + scf.yield %35, %33, %34 : tensor<64x256xf32, #blocked1>, !tt.ptr>, !tt.ptr> + } + %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1> + // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > + // CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : > + // CHECK: tt.store [[PTR2]], {{.*}} {boundaryCheck = array} : !tt.ptr> + %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %newptr = tt.advance %27, [%c32_i32, %c32_i32] : > + tt.store %newptr, %24 
{boundaryCheck = array} : !tt.ptr> + tt.return + } +} From 567d82ffb24471ff049ae80816f69bd142702a74 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 22 May 2025 18:27:28 +0000 Subject: [PATCH 11/16] Split unrelated changes in a separate PR Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/combine.mlir | 46 ------------- .../RemoveLayoutConversions.cpp | 69 +++++-------------- 2 files changed, 17 insertions(+), 98 deletions(-) diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index 1fa663e204..131de33d98 100644 --- a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -2472,49 +2472,3 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup tt.return } } - -// ----- - -// COM: Test that the DPAS layout is propagated to the store operation in the presence of an advance operation updating its base pointer. -// CHECK-NOT: #ttg.blocked<{.*}> -// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> -#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> -#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { - // CHECK-LABEL: matmul_kernel_with_block_pointers - tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { - %c1_i64 = arith.constant 1 : i64 - %c0_i32 = arith.constant 0 : i32 - %c0_i64 = arith.constant 0 : i64 - %c32_i32 = arith.constant 32 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #blocked1> - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> - %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > - %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > - %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr>, !tt.ptr>) : i32 { - // CHECK-NOT: ttg.convert_layout - %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major" } : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> - %36 = ttg.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas> - %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0> - %31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1> - %32 = tt.dot %30, %31, %36, inputPrecision = tf32 : 
tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
-    %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : >
-    %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : >
-    %35 = ttg.convert_layout %32 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
-    scf.yield %35, %33, %34 : tensor<64x256xf32, #blocked1>, !tt.ptr>, !tt.ptr>
-  }
-  %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
-  // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >
-  // CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : >
-  // CHECK: tt.store [[PTR2]], {{.*}} {boundaryCheck = array} : !tt.ptr>
-  %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
-  %newptr = tt.advance %27, [%c32_i32, %c32_i32] : >
-  tt.store %newptr, %24 {boundaryCheck = array} : !tt.ptr>
-  tt.return
-  }
-}
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index c90d27de92..4a5f93e08c 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -692,26 +692,6 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) {
   assertOp->setOperand(0, newOperand);
 }
 
-// Recursively update the operands in a chain of AdvanceOps, after setting the
-// pointer operand of the first one.
-static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp,
-                                 Value dataToStore) {
-  OpBuilder rewriter(advanceOp);
-  auto newAdvanceOp =
-      rewriter.create(advanceOp.getLoc(), makeTensorPtrOp.getType(),
-                      makeTensorPtrOp, advanceOp.getOffsets());
-
-  SmallVector advanceOpUsers(advanceOp->getUsers());
-  for (Operation *user : advanceOpUsers) {
-    if (auto storeOp = dyn_cast(user)) {
-      storeOp.setOperand(0, newAdvanceOp);
-      storeOp.setOperand(1, dataToStore);
-    } else if (auto advanceOp = dyn_cast(user)) {
-      updateAdvanceOpChain(advanceOp, makeTensorPtrOp, dataToStore);
-    }
-  }
-}
-
 bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
   // Disable 2D block store on LTS.
   if (!storeOp->getParentOfType()->hasAttr(
@@ -725,16 +705,13 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
   if (!isTensorPointerType(ptr.getType()))
     return false;
 
-  // Locate the operation that created the block pointer.
-  Operation *defOp = ptr.getDefiningOp();
-  while (auto advanceOp = dyn_cast(defOp))
-    defOp = advanceOp.getPtr().getDefiningOp();
-  assert(isa(defOp) &&
-         "MakeTensorPtrOp should be the only op that creates a tensor pointer");
-  auto makeTensorPtrOp = cast(defOp);
+  // 2D block stores are preceded by a MakeTensorPtrOp.
+  auto makeTensorPtrOp = ptr.getDefiningOp();
+  if (!makeTensorPtrOp)
+    return false;
 
-  // DPAS encoding have to be propagated if conversion from a DPAS layout to
-  // another layout has been done before.
+  // The DPAS encoding has to be propagated if a conversion from DPAS to
+  // another layout has been done before.
   auto convertOp = storeOp.getValue().getDefiningOp();
   PointerType newPtrType;
   Attribute encoding;
@@ -781,29 +758,21 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
     encoding = convertOpSrcType.getEncoding();
   }
 
-  // Create a new MakeTensorPtrOp with the new layout.
+  // We create a new MakeTensorPtrOp with the new data type.
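+  // For illustration, assuming a hypothetical store whose value was converted
+  // out of a DPAS layout, the pointee type of the re-created block pointer
+  // changes from the blocked encoding to the forwarded DPAS encoding:
+  //   before: !tt.ptr<tensor<64x256xf16, #blocked>>
+  //   after:  !tt.ptr<tensor<64x256xf16, #dpas>>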
OpBuilder rewriter(makeTensorPtrOp); - Value newMakeTensorPtrOp = rewriter.create( + Value newStorePtr = rewriter.create( makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(), makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(), - makeTensorPtrOp.getOffsets(), makeTensorPtrOp.getOrderAttr()); - - // Update the store operation with the new layout. - SmallVector makeTensorPtrOpUsers(makeTensorPtrOp->getUsers()); - Value dataToStore = getValueAs(value, encoding); - Block *storeBB = storeOp->getBlock(); - for (Operation *user : makeTensorPtrOpUsers) { - Block *userBB = user->getBlock(); - if (storeBB != userBB) - continue; + makeTensorPtrOp.getOffsets(), rewriter.getDenseI32ArrayAttr({1, 0})); - if (auto storeOp = dyn_cast(user)) { - storeOp.setOperand(0, newMakeTensorPtrOp); - storeOp.setOperand(1, dataToStore); - } else if (auto advanceOp = dyn_cast(user)) { - updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, dataToStore); - } - } + // The encoding of the StoreOp is updated with the new + // operands: + // - the Ptr created by the MakeTensorPtrOp with the new data + // type + // - the forwarded DPAS encoding. + Value newOperand = getValueAs(value, encoding); + storeOp.setOperand(0, newStorePtr); + storeOp.setOperand(1, newOperand); // If the DPAS encoding is forwarded, we do not need the // convertOp anymore if the convertOp was only used by the @@ -1638,7 +1607,6 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after propagating layouts forward:\n"; m.dump(); - assert(succeeded(verify(m)) && "Module verification failed"); }); cleanupConvertOps(); @@ -1649,7 +1617,6 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after backward remat:\n"; m.dump(); - assert(succeeded(verify(m)) && "Module verification failed"); }); // Cleanup dummy converts created during backward remat. @@ -1661,7 +1628,6 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after hoisting converts:\n"; m.dump(); - assert(succeeded(verify(m)) && "Module verification failed"); }); // 4. 
Apply clean up patterns to remove remove dead convert and dead code @@ -1677,7 +1643,6 @@ class TritonIntelGPURemoveLayoutConversionsPass LLVM_DEBUG({ DBGS() << "Module after final cleanups:\n"; m.dump(); - assert(succeeded(verify(m)) && "Module verification failed"); }); } }; From 93bec90398e92bb060825e8456f509852352ec19 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 26 May 2025 21:51:50 +0000 Subject: [PATCH 12/16] Add test 'load_in_while_loop' and fix it Signed-off-by: Tiotto, Ettore --- .../Intel/TensorDescToBlockPointer/loop.mlir | 28 +++++++++++ .../Transforms/TensorDescToBlockPointer.cpp | 49 +++++++++++-------- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir index 1748a6d6c4..10f11bc1fb 100644 --- a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir +++ b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir @@ -208,4 +208,32 @@ module { // CHECK: tt.return // CHECK: } + tt.func public @load_in_while_loop(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.get_program_id x : i32 + %3 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%c1_i64, %c1_i64] : , > + %5 = scf.while (%arg3 = %3) : (!tt.tensordesc>) -> (!tt.tensordesc>) { + %6 = arith.cmpi slt, %c0_i32, %arg2 : i32 + scf.condition(%6) %arg3 : !tt.tensordesc> + } do { + ^bb0(%arg3: !tt.tensordesc>): + %12 = tt.descriptor_load %arg3[%0, %c0_i32] : !tt.tensordesc> -> tensor<8x128xf32> + scf.yield %arg3 : !tt.tensordesc> + } + tt.return + } + // CHECK: tt.func public @load_in_while_loop({{.*}}) { + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK-NOT: tt.descriptor_load + // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}} : + // CHECK: scf.while ([[ARG3:%.*]] = [[PTR]]) : (!tt.ptr>) -> !tt.ptr> { + // CHECK: scf.condition({{.*}}) [[ARG3]] : !tt.ptr> + // CHECK: } do { + // CHECK: ^bb0([[ARG4:%.*]]: !tt.ptr>): + // CHECK: [[PTR1:%.*]] = tt.advance [[ARG4]], {{.*}} : + // CHECK: tt.load [[PTR1]] : !tt.ptr> + // CHECK: scf.yield [[ARG4]] : !tt.ptr> + // CHECK: } + } diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 102c48b069..5abaf49ed2 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -3,6 +3,7 @@ #include "intel/include/Utils/Utility.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "llvm/ADT/STLExtras.h" @@ -87,9 +88,10 @@ struct TritonIntelTensorDescToBlockPointer // Otherwise, return the existing one. The function takes the base, shape, // strides, offsets, sizes of the block pointer to create/lookup and its // tensor element type (to ensure the block pointer has the tensor layout). 
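+  // A minimal usage sketch, assuming the operand values are already in scope;
+  // two calls with identical operands at the same insertion point are
+  // expected to return the same tt.make_tensor_ptr rather than create two:
+  //   auto p1 = findOrCreateMakeTensorPtr(loc, base, shape, strides, offsets,
+  //                                       sizes, builder);
+  //   auto p2 = findOrCreateMakeTensorPtr(loc, base, shape, strides, offsets,
+  //                                       sizes, builder); // reuses p1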
- Value findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape, - ValueRange strides, ValueRange offsets, - ArrayRef sizes, OpBuilder &builder) { + tt::MakeTensorPtrOp + findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape, + ValueRange strides, ValueRange offsets, + ArrayRef sizes, OpBuilder &builder) { Block *block = builder.getInsertionBlock(); const Block::iterator insertPoint = builder.getInsertionPoint(); auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) { @@ -114,7 +116,7 @@ struct TritonIntelTensorDescToBlockPointer }); auto makeTensorPtrOp = [&]() { - Value makeTensorPtr = builder.create( + auto makeTensorPtr = builder.create( loc, base, shape, strides, offsets, sizes, builder.getDenseI32ArrayAttr({1, 0})); return makeTensorPtr; @@ -124,6 +126,24 @@ struct TritonIntelTensorDescToBlockPointer : makeTensorPtrOp(); } + void propagateToLoops(tt::MakeTensorPtrOp op) { + for (Operation *user : op->getUsers()) { + if (auto loopOp = dyn_cast(user)) { + for (auto [initArg, rgnInitArg, loopRes, yieldVal] : + llvm::zip(loopOp.getInits(), loopOp.getRegionIterArgs(), + loopOp->getResults(), loopOp.getYieldedValues())) { + assert(rgnInitArg.getType() == loopRes.getType() && + rgnInitArg.getType() == yieldVal.getType() && "Type mismatch"); + if (rgnInitArg.getType() != initArg.getType()) { + rgnInitArg.setType(initArg.getType()); + loopRes.setType(initArg.getType()); + yieldVal.setType(initArg.getType()); + } + } + } + } + } + LogicalResult rewriteMakeTensorDescriptorOp(tt::MakeTensorDescOp op) { assert(op && "Expecting a valid operation"); LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n"); @@ -150,31 +170,18 @@ struct TritonIntelTensorDescToBlockPointer sizes.push_back(static_cast(size)); } - Value tensorPtr = findOrCreateMakeTensorPtr( + auto tensorPtr = findOrCreateMakeTensorPtr( loc, op.getBase(), shapes, strides, offsets, sizes, builder); LLVM_DEBUG({ llvm::dbgs() << "With:\n"; llvm::dbgs().indent(2) << tensorPtr << "\n"; }); - op.replaceAllUsesWith(tensorPtr); + op->replaceAllUsesWith(tensorPtr); cleanUp.insert(op); - for (Operation *user : tensorPtr.getUsers()) { - if (auto forOp = dyn_cast(user)) { - for (auto [initArg, rgnInitArg, loopRes, yieldVal] : - llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs(), - forOp.getResults(), forOp.getYieldedValues())) { - assert(rgnInitArg.getType() == loopRes.getType() && - rgnInitArg.getType() == yieldVal.getType() && "Type mismatch"); - if (rgnInitArg.getType() != initArg.getType()) { - rgnInitArg.setType(initArg.getType()); - loopRes.setType(initArg.getType()); - yieldVal.setType(initArg.getType()); - } - } - } - } + // Propagate the `tensorPtr` type to loops init args, etc... 
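+    // For example (hypothetical IR), when the new block pointer is carried as
+    // an scf.for iter_arg, the region argument, yielded value and loop result
+    // are all retyped to match the new init arg:
+    //   %ptr = tt.make_tensor_ptr ... : !tt.ptr<tensor<8x128xf32>>
+    //   %res = scf.for ... iter_args(%arg = %ptr) -> (!tt.ptr<tensor<8x128xf32>>) {
+    //     scf.yield %arg : !tt.ptr<tensor<8x128xf32>>
+    //   }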
+ propagateToLoops(tensorPtr); return success(); } From f61f753486cdf3ba72dda6b1ad975e6c466b2fc9 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 26 May 2025 21:53:52 +0000 Subject: [PATCH 13/16] Fix precommit Signed-off-by: Tiotto, Ettore --- test/Triton/Intel/TensorDescToBlockPointer/loop.mlir | 4 ++-- .../Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir index 10f11bc1fb..150bfb9907 100644 --- a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir +++ b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir @@ -209,7 +209,7 @@ module { // CHECK: } tt.func public @load_in_while_loop(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { - %c0_i32 = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %0 = tt.get_program_id x : i32 %3 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%c1_i64, %c1_i64] : , > @@ -229,7 +229,7 @@ module { // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}} : // CHECK: scf.while ([[ARG3:%.*]] = [[PTR]]) : (!tt.ptr>) -> !tt.ptr> { // CHECK: scf.condition({{.*}}) [[ARG3]] : !tt.ptr> - // CHECK: } do { + // CHECK: } do { // CHECK: ^bb0([[ARG4:%.*]]: !tt.ptr>): // CHECK: [[PTR1:%.*]] = tt.advance [[ARG4]], {{.*}} : // CHECK: tt.load [[PTR1]] : !tt.ptr> diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 5abaf49ed2..d894651b39 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -139,7 +139,7 @@ struct TritonIntelTensorDescToBlockPointer loopRes.setType(initArg.getType()); yieldVal.setType(initArg.getType()); } - } + } } } } From 27a78c1d83654683e31682a19c1865a718862d54 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 27 May 2025 16:18:32 +0000 Subject: [PATCH 14/16] Add test 'while_uses_tdesc_yielded_by_for_loop' and fix it Signed-off-by: Tiotto, Ettore --- .../Intel/TensorDescToBlockPointer/loop.mlir | 43 ++++++++++++++++++- .../Transforms/TensorDescToBlockPointer.cpp | 39 ++++++++++------- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir index 150bfb9907..6bb28f958b 100644 --- a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir +++ b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir @@ -208,6 +208,7 @@ module { // CHECK: tt.return // CHECK: } + // COM: While loop contains a descriptor load operation. 
tt.func public @load_in_while_loop(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 @@ -226,8 +227,46 @@ module { // CHECK: tt.func public @load_in_while_loop({{.*}}) { // CHECK-NOT: tt.make_tensor_descriptor // CHECK-NOT: tt.descriptor_load - // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}} : - // CHECK: scf.while ([[ARG3:%.*]] = [[PTR]]) : (!tt.ptr>) -> !tt.ptr> { + // CHECK: [[TENSOR_PTR:%.*]] = tt.make_tensor_ptr {{.*}} : + // CHECK: scf.while ([[ARG3:%.*]] = [[TENSOR_PTR]]) : (!tt.ptr>) -> !tt.ptr> { + // CHECK: scf.condition({{.*}}) [[ARG3]] : !tt.ptr> + // CHECK: } do { + // CHECK: ^bb0([[ARG4:%.*]]: !tt.ptr>): + // CHECK: [[PTR1:%.*]] = tt.advance [[ARG4]], {{.*}} : + // CHECK: tt.load [[PTR1]] : !tt.ptr> + // CHECK: scf.yield [[ARG4]] : !tt.ptr> + // CHECK: } + + // COM: For loop yields a tensor descriptor used by a while loop. + tt.func public @while_uses_tdesc_yielded_by_for_loop(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { + %c1_i64 = arith.constant 1 : i64 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %2 = arith.extsi %arg2 : i32 to i64 + %3 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%2, %c1_i64] : , > + %4 = scf.for %arg3 = %c0_i32 to %arg2 step %c128_i32 iter_args(%arg4 = %3) -> (!tt.tensordesc>) : i32 { + scf.yield %arg4 : !tt.tensordesc> + } + %5 = scf.while (%arg3 = %4) : (!tt.tensordesc>) -> (!tt.tensordesc>) { + %6 = arith.cmpi slt, %0, %arg2 : i32 + scf.condition(%6) %arg3 : !tt.tensordesc> + } do { + ^bb0(%arg3: !tt.tensordesc>): + %12 = tt.descriptor_load %arg3[%c8_i32, %c8_i32] : !tt.tensordesc> -> tensor<8x128xf32> + scf.yield %arg3 : !tt.tensordesc> + } + tt.return + } + // CHECK: tt.func public @while_uses_tdesc_yielded_by_for_loop({{.*}}) { + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK-NOT: tt.descriptor_load + // CHECK: [[TENSOR_PTR:%.*]] = tt.make_tensor_ptr {{.*}} : + // CHECK: [[FOR_RES:%.+]] = scf.for [[IV:%.+]] = {{.*}} iter_args([[ARG3:%.*]] = [[TENSOR_PTR]]) -> (!tt.ptr>) : i32 { + // CHECK: scf.yield {{.*}} : !tt.ptr> + // CHECK: } + // CHECK: scf.while ([[ARG3:%.*]] = [[FOR_RES]]) : (!tt.ptr>) -> !tt.ptr> { // CHECK: scf.condition({{.*}}) [[ARG3]] : !tt.ptr> // CHECK: } do { // CHECK: ^bb0([[ARG4:%.*]]: !tt.ptr>): diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index d894651b39..5b40bbe21d 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -126,21 +126,28 @@ struct TritonIntelTensorDescToBlockPointer : makeTensorPtrOp(); } - void propagateToLoops(tt::MakeTensorPtrOp op) { - for (Operation *user : op->getUsers()) { - if (auto loopOp = dyn_cast(user)) { - for (auto [initArg, rgnInitArg, loopRes, yieldVal] : - llvm::zip(loopOp.getInits(), loopOp.getRegionIterArgs(), - loopOp->getResults(), loopOp.getYieldedValues())) { - assert(rgnInitArg.getType() == loopRes.getType() && - rgnInitArg.getType() == yieldVal.getType() && "Type mismatch"); - if (rgnInitArg.getType() != initArg.getType()) { - rgnInitArg.setType(initArg.getType()); - loopRes.setType(initArg.getType()); - yieldVal.setType(initArg.getType()); - } + void propagateToLoops(Operation *op) { + if (auto loopOp = dyn_cast(op)) { + bool updated = false; + for (auto [initArg, rgnInitArg, 
loopRes, yieldVal] :
+          llvm::zip(loopOp.getInits(), loopOp.getRegionIterArgs(),
+                    loopOp->getResults(), loopOp.getYieldedValues())) {
+        Type initArgType = initArg.getType();
+        Type rgnInitArgType = rgnInitArg.getType();
+        assert(rgnInitArgType == loopRes.getType() &&
+               rgnInitArgType == yieldVal.getType() && "Type mismatch");
+        if (rgnInitArgType != initArgType) {
+          rgnInitArg.setType(initArgType);
+          yieldVal.setType(initArgType);
+          loopRes.setType(initArgType);
+          updated = true;
         }
       }
+      if (!updated)
+        return;
+
+      for (Operation *user : loopOp->getUsers())
+        propagateToLoops(user);
     }
   }
 
@@ -180,8 +187,10 @@ struct TritonIntelTensorDescToBlockPointer
     op->replaceAllUsesWith(tensorPtr);
     cleanUp.insert(op);
 
-    // Propagate the `tensorPtr` type to loops init args, etc...
-    propagateToLoops(tensorPtr);
+    // Propagate the `tensorPtr` type to loops init args, yielded values,
+    // results, ... (if necessary).
+    for (Operation *user : tensorPtr->getUsers())
+      propagateToLoops(user);
 
     return success();
   }

From c55925f7bd27280971b8cdaeef48c82c6ed445fa Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore" 
Date: Tue, 27 May 2025 18:14:38 +0000
Subject: [PATCH 15/16] Add test 'while_loop_with_if_stmt' and fix it

Signed-off-by: Tiotto, Ettore 
---
 .../Intel/TensorDescToBlockPointer/loop.mlir  | 48 ++++++++++++++++++-
 .../Transforms/TensorDescToBlockPointer.cpp   | 15 +++++-
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir
index 6bb28f958b..15f3810b78 100644
--- a/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir
+++ b/test/Triton/Intel/TensorDescToBlockPointer/loop.mlir
@@ -270,9 +270,53 @@ module {
   // CHECK:         scf.condition({{.*}}) [[ARG3]] : !tt.ptr>
   // CHECK:       } do {
   // CHECK:       ^bb0([[ARG4:%.*]]: !tt.ptr>):
-  // CHECK:         [[PTR1:%.*]] = tt.advance [[ARG4]], {{.*}} :
-  // CHECK:         tt.load [[PTR1]] : !tt.ptr>
+  // CHECK:         [[TENSOR_PTR1:%.*]] = tt.advance [[ARG4]], {{.*}} :
+  // CHECK:         tt.load [[TENSOR_PTR1]] : !tt.ptr>
   // CHECK:         scf.yield [[ARG4]] : !tt.ptr>
   // CHECK:       }
 
+  // COM: While loop containing an if statement yielding a tensor descriptor used by a load operation.
+ tt.func public @while_loop_with_if_stmt(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) { + %c1_i64 = arith.constant 1 : i64 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %3 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%c1_i64, %c1_i64] : , > + %5 = scf.while (%arg3 = %3) : (!tt.tensordesc>) -> (!tt.tensordesc>) { + %6 = arith.cmpi slt, %0, %c128_i32 : i32 + scf.condition(%6) %arg3 : !tt.tensordesc> + } do { + ^bb0(%arg3: !tt.tensordesc>): + %7 = arith.cmpi eq, %0, %c0_i32 : i32 + %10 = arith.select %7, %c1_i64, %c1_i64 : i64 + %11 = scf.if %7 -> (!tt.tensordesc>) { + %16 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%c1_i64, %c1_i64] : , > + scf.yield %16 : !tt.tensordesc> + } else { + scf.yield %arg3 : !tt.tensordesc> + } + %12 = tt.descriptor_load %11[%c128_i32, %c128_i32] : !tt.tensordesc> -> tensor<8x128xf32> + scf.yield %11 : !tt.tensordesc> + } + tt.return + } + // CHECK: tt.func public @while_loop_with_if_stmt({{.*}}) { + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK-NOT: tt.descriptor_load + // CHECK: [[TENSOR_PTR:%.*]] = tt.make_tensor_ptr {{.*}} : + // CHECK: scf.while ([[ARG3:%.*]] = [[TENSOR_PTR]]) : (!tt.ptr>) -> !tt.ptr> { + // CHECK: scf.condition({{.*}}) [[ARG3]] : !tt.ptr> + // CHECK: } do { + // CHECK: ^bb0([[ARG4:%.*]]: !tt.ptr>): + // CHECK: [[IF_RES:%.*]] = scf.if {{.*}} -> (!tt.ptr>) { + // CHECK: [[TENSOR_PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : + // CHECK: scf.yield [[TENSOR_PTR1]] : !tt.ptr> + // CHECK: } else { + // CHECK: scf.yield [[ARG4]] : !tt.ptr> + // CHECK: } + // CHECK: [[TENSOR_PTR2:%.*]] = tt.advance [[IF_RES]], {{.*}} : + // CHECK: tt.load [[TENSOR_PTR2]] : !tt.ptr> + // CHECK: scf.yield [[IF_RES]] : !tt.ptr> + // CHECK: } + } diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 5b40bbe21d..2925387170 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -129,9 +129,9 @@ struct TritonIntelTensorDescToBlockPointer void propagateToLoops(Operation *op) { if (auto loopOp = dyn_cast(op)) { bool updated = false; - for (auto [initArg, rgnInitArg, loopRes, yieldVal] : + for (auto [initArg, rgnInitArg, yieldVal, loopRes] : llvm::zip(loopOp.getInits(), loopOp.getRegionIterArgs(), - loopOp->getResults(), loopOp.getYieldedValues())) { + loopOp.getYieldedValues(), loopOp->getResults())) { Type initArgType = initArg.getType(); Type rgnInitArgType = rgnInitArg.getType(); assert(rgnInitArgType == loopRes.getType() && @@ -146,6 +146,17 @@ struct TritonIntelTensorDescToBlockPointer if (!updated) return; + // For while loops we also need to update the "after" region arguments. + if (auto loopOp = dyn_cast(op)) { + for (auto [initArg, rgnAfterArg] : + llvm::zip(loopOp.getInits(), loopOp.getAfterArguments())) { + Type initArgType = initArg.getType(); + if (rgnAfterArg.getType() != initArgType) + rgnAfterArg.setType(initArgType); + } + } + + // Propagate the loop results to their users. 
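+    // For example (hypothetical IR), an scf.for result feeding an scf.while
+    // init arg forces the while loop's region arguments and results to be
+    // retyped as well, hence the recursion through the users:
+    //   %r = scf.for ... -> (!tt.ptr<tensor<8x128xf32>>) { ... }
+    //   %w = scf.while (%a = %r) : (!tt.ptr<tensor<8x128xf32>>) -> (...)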
for (Operation *user : loopOp->getUsers()) propagateToLoops(user); } From 9f41668e0686c1185d2b89ce0bec7422fc0661c2 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 27 May 2025 20:17:00 +0000 Subject: [PATCH 16/16] Fix precommit Signed-off-by: Tiotto, Ettore --- .../Triton/Transforms/TensorDescToBlockPointer.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index 1e544c07aa..5efc369b61 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -221,7 +221,7 @@ struct TritonIntelTensorDescToBlockPointer "Expecting a block ptr"); auto ptrType = cast(ptr.getType()); auto tensorType = cast(ptrType.getPointeeType()); - + ptr = builder.create(loc, ptr.getType(), ptr, op.getIndices()); @@ -229,14 +229,12 @@ struct TritonIntelTensorDescToBlockPointer for (size_t i = 0; i < tensorType.getRank(); ++i) boundaryCheck.push_back(i); -// for (size_t i = 0; i < makeTensorDescOp.getShape().size(); ++i) -// boundaryCheck.push_back(i); constexpr bool isLoad = std::is_same_v; if constexpr (isLoad) { - auto loadOp = builder.createOrFold(loc, ptr, boundaryCheck, - /*padding*/ std::nullopt, - op.getCache(), op.getEvict(), - /*volatile*/ false); + auto loadOp = builder.createOrFold( + loc, ptr, boundaryCheck, + /*padding*/ std::nullopt, op.getCache(), op.getEvict(), + /*volatile*/ false); LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n"); op.replaceAllUsesWith(loadOp); } else {