From 3df1acf0b1e3c83096f1aeec786d56bdab1b49b1 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 17 Apr 2025 21:24:07 +0000
Subject: [PATCH 1/3] Fix issue #3805

Signed-off-by: Tiotto, Ettore
---
 .../TritonGPUToLLVM/ControlFlowOpToLLVM.cpp | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
index 8541f9e877..2eddf5735d 100644
--- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
@@ -1,3 +1,4 @@
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
@@ -100,8 +101,19 @@ struct CallOpConversion : public ConvertOpToLLVMPattern {
       opOffsetVal = b.i32_val(opOffset);
     }
 
-    promotedOperands.push_back(LLVM::getGlobalScratchPtr(
-        loc, rewriter, targetInfo, caller, opOffsetVal));
+    Value globalScratchPtr = LLVM::getGlobalScratchPtr(loc, rewriter, targetInfo,
+                                                       caller, opOffsetVal);
+    auto callee = cast<FunctionOpInterface>(callOp.resolveCallable());
+    auto lastArgType =
+        callee.getArguments()[callee.getNumArguments() - 1].getType();
+
+    if (lastArgType != globalScratchPtr.getType()) {
+      auto zeroOp = rewriter.create<LLVM::ZeroOp>(loc, lastArgType);
+      promotedOperands.push_back(zeroOp);
+      return promotedOperands;
+    }
+
+    promotedOperands.push_back(globalScratchPtr);
 
     return promotedOperands;
   }

From 82fb6f50027418c7d4c1512640fa48cc45ebdcaa Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Tue, 22 Apr 2025 18:17:05 +0000
Subject: [PATCH 2/3] Improve test case, enable noinline functions in test_tensor_descriptor

Signed-off-by: Tiotto, Ettore
---
 .../TritonGPUToLLVM/ControlFlowOpToLLVM.cpp | 16 ++--------
 .../tritonintelgpu-rewrite-stack-ptr.mlir   | 30 +++++++++----------
 2 files changed, 17 insertions(+), 29 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
index 2eddf5735d..8541f9e877 100644
--- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
@@ -1,4 +1,3 @@
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
@@ -101,19 +100,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern {
       opOffsetVal = b.i32_val(opOffset);
     }
 
-    Value globalScratchPtr = LLVM::getGlobalScratchPtr(loc, rewriter, targetInfo,
-                                                       caller, opOffsetVal);
-    auto callee = cast<FunctionOpInterface>(callOp.resolveCallable());
-    auto lastArgType =
-        callee.getArguments()[callee.getNumArguments() - 1].getType();
-
-    if (lastArgType != globalScratchPtr.getType()) {
-      auto zeroOp = rewriter.create<LLVM::ZeroOp>(loc, lastArgType);
-      promotedOperands.push_back(zeroOp);
-      return promotedOperands;
-    }
-
-    promotedOperands.push_back(globalScratchPtr);
+    promotedOperands.push_back(LLVM::getGlobalScratchPtr(
+        loc, rewriter, targetInfo, caller, opOffsetVal));
 
     return promotedOperands;
   }

diff --git a/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir b/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir
index 8ed535929b..b27b4061f7 100644
--- a/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir
+++ b/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir
@@ -1,19 +1,19 @@
 // RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --tritonintelgpu-rewrite-stack-ptr | FileCheck %s
-module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
-  // CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
-  tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32} {
+  // CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  // CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
+  tt.func public @kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) {
     %0 = tt.load %arg0 : !tt.ptr<f32>
     %1 = tt.load %arg1 : !tt.ptr<f32>
     // CHECK: llvm.mlir.poison : !llvm.ptr<3>
-    // CHECK: llvm.call spir_funccc @noinline_simple_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %18, %arg2)
-    tt.call @noinline_simple_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) : (f32, f32, !tt.ptr<f32>) -> ()
+    // CHECK: llvm.call spir_funccc @noinline_simple_fn(%8, %17, %arg2, %18, %arg2)
+    tt.call @noinline_simple_fn(%0, %1, %arg2) : (f32, f32, !tt.ptr<f32>) -> ()
     tt.return
   }
-  // CHECK: llvm.func internal @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
-  tt.func private @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg1: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg2: !tt.ptr<f32> {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 16 : i64}) attributes {noinline = true} {
-    %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+  // CHECK-LABEL: llvm.func internal @noinline_simple_fn(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
+  tt.func private @noinline_simple_fn(%arg0: f32, %arg1: f32, %arg2: !tt.ptr<f32>) attributes {noinline = true} {
+    %0 = arith.addf %arg0, %arg1 : f32
     tt.store %arg2, %0 : !tt.ptr<f32>
     tt.return
   }
@@ -25,19 +25,19 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
+module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, "ttg.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
   // CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>)
-  tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+  tt.func public @kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) {
     %0 = tt.load %arg0 : !tt.ptr<f32>
     %1 = tt.load %arg1 : !tt.ptr<f32>
-    // CHECK: llvm.call spir_funccc @noinline_shared_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %arg3, %arg2)
-    tt.call @noinline_shared_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
+    // CHECK: llvm.call spir_funccc @noinline_shared_fn(%8, %17, %arg2, %arg3, %arg2)
+    tt.call @noinline_shared_fn(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
     tt.return
   }
-  // CHECK: llvm.func internal @noinline_shared_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
+  // CHECK: llvm.func internal @noinline_shared_fn(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
   // CHECK: llvm.getelementptr %arg3[{{.*}}]
-  tt.func private @noinline_shared_fn__fp32_fp32_Pfp32__(%arg0: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg1: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg2: !tt.ptr<f32> {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 16 : i64}) attributes {noinline = true} {
+  tt.func private @noinline_shared_fn(%arg0: f32, %arg1: f32, %arg2: !tt.ptr<f32>) attributes {noinline = true} {
     %cst = arith.constant dense<16> : tensor<16x1xi32, #blocked>
     %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
     %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked>

From 53f16f37f7d91139ca442de6745522e3ebc4e573 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Tue, 22 Apr 2025 20:57:23 +0000
Subject: [PATCH 3/3] Create Intel version of CallOpConversion pattern

Signed-off-by: Tiotto, Ettore
---
 .../test/unit/intel/test_tensor_descriptor.py |   2 +-
 .../tritonintelgpu-rewrite-stack-ptr.mlir     |   4 +-
 .../ControlFlowOpToLLVM.cpp                   | 118 ++++++++++++++++++
 3 files changed, 121 insertions(+), 3 deletions(-)

diff --git a/python/test/unit/intel/test_tensor_descriptor.py b/python/test/unit/intel/test_tensor_descriptor.py
index 76b003c276..20c01b39e3 100644
--- a/python/test/unit/intel/test_tensor_descriptor.py
+++ b/python/test/unit/intel/test_tensor_descriptor.py
@@ -245,7 +245,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 
     torch.testing.assert_close(expect, actual)
 
-@triton.jit(noinline=False)
+@triton.jit(noinline=True)
 def tensor_descriptor_in_function_helper(out_ptr, in_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
     in_desc = tl.make_tensor_descriptor(
         in_ptr,
diff --git a/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir b/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir
index b27b4061f7..afbb330774 100644
--- a/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir
+++ b/test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir
@@ -12,7 +12,7 @@ module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-warps" = 1
     tt.return
   }
   // CHECK-LABEL: llvm.func internal @noinline_simple_fn(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
-  tt.func private @noinline_simple_fn(%arg0: f32, %arg1: f32, %arg2: !tt.ptr<f32>) attributes {noinline = true} {
+  tt.func private @noinline_simple_fn(%arg0: f32, %arg1: f32, %arg2: !tt.ptr<f32>) attributes {noinline = true} {
     %0 = arith.addf %arg0, %arg1 : f32
     tt.store %arg2, %0 : !tt.ptr<f32>
     tt.return
@@ -25,7 +25,7 @@
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, "ttg.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
   // CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>)
   tt.func public @kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) {
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp
index 68d977983f..a086d7c96f 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp
@@ -1,5 +1,9 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include
 
 namespace {
 
@@ -16,12 +20,126 @@ struct FixCallCConv : public ConvertOpToLLVMPattern {
   }
 };
 
+struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
+  CallOpConversion(LLVMTypeConverter &converter,
+                   const TargetInfoBase &targetInfo, PatternBenefit benefit)
+      : ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
+        targetInfo(targetInfo) {}
+
+  LogicalResult
+  matchAndRewrite(triton::CallOp callOp,
+                  typename triton::CallOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto promotedOperands = promoteOperands(callOp, adaptor, rewriter);
+    auto newCallOp =
+        convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter);
+    if (!newCallOp)
+      return failure();
+    auto results = getCallOpResults(callOp, newCallOp, rewriter);
+    rewriter.replaceOp(callOp, results);
+    return success();
+  }
+
+private:
+  SmallVector<Value>
+  promoteOperands(triton::CallOp callOp,
+                  typename triton::CallOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    // Get the last argument of the caller, which is the current stack pointer
+    // of shared memory and append it to the operands of the callOp.
+    auto loc = callOp.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    auto caller = callOp->getParentOfType<FunctionOpInterface>();
+    auto promotedOperands = this->getTypeConverter()->promoteOperands(
+        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
+        adaptor.getOperands(), rewriter);
+    if (!caller->hasAttr("allocation.offset")) {
+      auto base = LLVM::getStackPointer(rewriter, caller);
+      promotedOperands.push_back(base);
+    } else {
+      auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp);
+      promotedOperands.push_back(base);
+    }
+
+    auto opOffsetAttr = callOp->getAttrOfType<IntegerAttr>(
+        "ttg.global_scratch_memory_offset");
+    Value opOffsetVal;
+    if (opOffsetAttr) {
+      auto opOffset = opOffsetAttr.getValue().getZExtValue();
+      opOffsetVal = b.i32_val(opOffset);
+    }
+
+    Value globalScratchPtr = LLVM::getGlobalScratchPtr(
+        loc, rewriter, targetInfo, caller, opOffsetVal);
+    auto callee = cast<FunctionOpInterface>(callOp.resolveCallable());
+    auto lastArgType =
+        callee.getArguments()[callee.getNumArguments() - 1].getType();
+    if (lastArgType != globalScratchPtr.getType()) {
+      auto zeroOp = rewriter.create<LLVM::ZeroOp>(loc, lastArgType);
+      promotedOperands.push_back(zeroOp);
+      return promotedOperands;
+    }
+
+    promotedOperands.push_back(globalScratchPtr);
+
+    return promotedOperands;
+  }
+
+  LLVM::CallOp
+  convertCallOpToLLVMCallOp(triton::CallOp callOp,
+                            ArrayRef<Value> promotedOperands,
+                            ConversionPatternRewriter &rewriter) const {
+    // Pack the result types into a struct.
+    Type packedResult = nullptr;
+    unsigned numResults = callOp.getNumResults();
+    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
+
+    if (numResults != 0) {
+      if (!(packedResult =
+                this->getTypeConverter()->packFunctionResults(resultTypes)))
+        return nullptr;
+    }
+    auto newCallOp = rewriter.create<LLVM::CallOp>(
+        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
+        promotedOperands, callOp->getAttrs());
+    newCallOp.getProperties().setOpBundleSizes(
+        rewriter.getDenseI32ArrayAttr({}));
+    newCallOp.getProperties().setOperandSegmentSizes(
+        {static_cast<int32_t>(promotedOperands.size()), 0});
+    return newCallOp;
+  }
+
+  SmallVector<Value>
+  getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
+                   ConversionPatternRewriter &rewriter) const {
+    auto numResults = callOp.getNumResults();
+    SmallVector<Value> results;
+    if (numResults < 2) {
+      // If < 2 results, packing did not do anything and we can just return.
+      results.append(newCallOp.result_begin(), newCallOp.result_end());
+    } else {
+      // Otherwise, it had been converted to an operation producing a structure.
+      // Extract individual results from the structure and return them as list.
+      results.reserve(numResults);
+      for (unsigned i = 0; i < numResults; ++i) {
+        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            callOp.getLoc(), newCallOp->getResult(0), i));
+      }
+    }
+    return results;
+  }
+  const TargetInfoBase &targetInfo;
+};
+
 } // namespace
 
 void mlir::triton::intel::populateControlFlowOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     const TargetInfoBase &targetInfo, PatternBenefit benefit) {
   patterns.add<FixCallCConv>(typeConverter);
+  // Overwrite the CallOpConversion pattern added by the call to
+  // populateControlFlowOpToLLVMPattern.
+  patterns.add<CallOpConversion>(typeConverter, targetInfo,
+                                 benefit.getBenefit() + 1);
   mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
                                                    targetInfo, benefit);
 }