Skip to content

[tensor-descriptor]: Extend support when tensor descriptor created in control flow #4152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4384ad1
Ensure block ptr is created with the same layout as the descriptor_lo…
etiotto May 9, 2025
f0ce91c
Remove naked print and unnecessary headers
etiotto May 9, 2025
04f1c1d
Merge remote-tracking branch 'origin/main' into etiotto.tensor_desc_t…
etiotto May 14, 2025
3543e6e
WIP: TensorDescToBlockPtr updates
etiotto May 14, 2025
38ef6c3
WIP: RemoveLayoutConversion improvement for tt.advance operation
etiotto May 15, 2025
e4d5d7d
Merge remote-tracking branch 'origin/main' into etiotto.tensor_desc_t…
etiotto May 15, 2025
9eb16ec
WIP: TensorDescToBlockPtr updates
etiotto May 15, 2025
7f9bbc9
WIP: TensorDescToBlockPtr updates
etiotto May 15, 2025
f6ed66a
WIP: TensorDescToBlockPtr updates
etiotto May 16, 2025
cb4bb2e
WIP: TensorDescToBlockPtr updates
etiotto May 20, 2025
f6ce50a
Merge branch 'main' into etiotto.tensor_desc_to_block_ptr.1
etiotto May 20, 2025
b439d24
Merge remote-tracking branch 'origin/main' into etiotto.tensor_desc_t…
etiotto May 22, 2025
e5f74b8
Address code review comments
etiotto May 22, 2025
ff321f5
Merge remote-tracking branch 'origin/main' into etiotto.tensor_desc_t…
etiotto May 22, 2025
7fcbe40
Add unit test
etiotto May 22, 2025
567d82f
Split unrelated changes in a separate PR
etiotto May 22, 2025
75e5ef4
Merge branch 'main' into etiotto.tensor_desc_to_block_ptr.1
whitneywhtsang May 23, 2025
7be4b3f
Merge branch 'main' into etiotto.tensor_desc_to_block_ptr.1
etiotto May 26, 2025
463adf8
Merge remote-tracking branch 'origin/main' into etiotto.tensor_desc_t…
etiotto May 26, 2025
93bec90
Add test 'load_in_while_loop' and fix it
etiotto May 26, 2025
f61f753
Fix precommit
etiotto May 26, 2025
27a78c1
Add test 'while_uses_tdesc_yielded_by_for_loop' and fix it
etiotto May 27, 2025
c55925f
Add test 'while_loop_with_if_stmt' and fix it
etiotto May 27, 2025
c39325d
Merge remote-tracking branch 'origin/main' into etiotto.tensor_desc_t…
etiotto May 27, 2025
9f41668
Fix precommit
etiotto May 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions test/Triton/Intel/TensorDescToBlockPointer/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
// RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s

module {
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>

module attributes {"ttg.num-warps" = 4 : i32} {
tt.func public @test_load(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32) {
%c1_i64 = arith.constant 1 : i64
%c64_i32 = arith.constant 64 : i32
%c8_i32 = arith.constant 8 : i32
%0 = arith.extsi %arg2 : i32 to i64
%desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
%load = tt.descriptor_load %desc[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32>
%desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
%load1 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32>
%load2 = tt.descriptor_load %desc1[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<16x128xf32>> -> tensor<16x128xf32, #blocked>
tt.return
}
// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: tt.func public @test_load([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
// CHECK-NOT: tt.make_tensor_descriptor
// CHECK-NOT: tt.descriptor_load
Expand All @@ -18,8 +22,10 @@ module {
// CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
// CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
// CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
// CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
// CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<16x128xf32>>
// CHECK: [[TENSOR_PTR1:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
// CHECK: [[LOAD1:%.+]] = tt.load [[TENSOR_PTR1]] : !tt.ptr<tensor<16x128xf32>>
// CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32, #[[$BLOCKED]]>>
// CHECK: [[LOAD2:%.+]] = tt.load [[TENSOR_PTR2]] : !tt.ptr<tensor<16x128xf32, #[[$BLOCKED]]>>
// CHECK: tt.return
// CHECK: }

Expand All @@ -28,9 +34,11 @@ module {
%c64_i32 = arith.constant 64 : i32
%c8_i32 = arith.constant 8 : i32
%cst = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
%cst1 = arith.constant dense<1.000000e+00> : tensor<16x128xf32, #blocked>
%0 = arith.extsi %arg2 : i32 to i64
%desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
tt.descriptor_store %desc[%c8_i32, %c64_i32], %cst : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32>
%desc1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f32>, <tensor<16x128xf32>>
tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32>
tt.descriptor_store %desc1[%c8_i32, %c64_i32], %cst1 : !tt.tensordesc<tensor<16x128xf32>>, tensor<16x128xf32, #blocked>
tt.return
}
// CHECK: tt.func public @test_store([[PARAM_0:%.+]]: !tt.ptr<f32>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
Expand All @@ -40,10 +48,13 @@ module {
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
// CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
// CHECK-DAG: [[CST:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32>
// CHECK-DAG: [[CST1:%.+]] = arith.constant dense<1.000000e+00> : tensor<16x128xf32, #[[$BLOCKED]]>
// CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
// CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
// CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
// CHECK: tt.store [[TENSOR_PTR]], [[CST]] : !tt.ptr<tensor<16x128xf32>>
// CHECK: [[TENSOR_PTR1:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
// CHECK: tt.store [[TENSOR_PTR1]], [[CST]] : !tt.ptr<tensor<16x128xf32>>
// CHECK: [[TENSOR_PTR2:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32, #[[$BLOCKED]]>>
// CHECK: tt.store [[TENSOR_PTR2]], [[CST1]] : !tt.ptr<tensor<16x128xf32, #[[$BLOCKED]]>>
// CHECK: tt.return
// CHECK: }
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "triton-intel-tdesc-to-block-pointer"
Expand Down Expand Up @@ -121,11 +120,19 @@ struct TritonIntelTensorDescToBlockPointer
return (yieldedVal != blockArg);
}

// Create a new block pointer if a suitable one doesn't already exist.
// Otherwise, return the existing one. The function takes the base, shape,
// strides, offsets, sizes of the block pointer to create/lookup and its
// tensor element type (to ensure the block pointer has the tensor layout).
Value findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape,
ValueRange strides, ValueRange offsets,
ArrayRef<int32_t> sizes, OpBuilder &builder) {
ArrayRef<int32_t> sizes,
RankedTensorType tensorType,
OpBuilder &builder) {
Block *block = builder.getInsertionBlock();
const Block::iterator insertPoint = builder.getInsertionPoint();
auto ptrType = tt::PointerType::get(
tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup);

auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) {
if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
Expand All @@ -138,7 +145,9 @@ struct TritonIntelTensorDescToBlockPointer
}
return true;
};
return makeTensorPtrOp.getBase() == base &&

return makeTensorPtrOp.getType() == ptrType &&
makeTensorPtrOp.getBase() == base &&
makeTensorPtrOp.getShape() == shape &&
makeTensorPtrOp.getStrides() == strides &&
makeTensorPtrOp.getOffsets() == offsets &&
Expand All @@ -147,10 +156,16 @@ struct TritonIntelTensorDescToBlockPointer
return false;
});

auto makeTensorPtrOp = [&]() {
Value makeTensorPtr = builder.create<tt::MakeTensorPtrOp>(
loc, base, shape, strides, offsets, sizes,
builder.getDenseI32ArrayAttr({1, 0}));
makeTensorPtr.setType(ptrType);
return makeTensorPtr;
};

return (it != insertPoint) ? cast<tt::MakeTensorPtrOp>(*it)
: builder.createOrFold<tt::MakeTensorPtrOp>(
loc, base, shape, strides, offsets, sizes,
builder.getDenseI32ArrayAttr({1, 0}));
: makeTensorPtrOp();
}

template <typename OpTy,
Expand All @@ -176,6 +191,11 @@ struct TritonIntelTensorDescToBlockPointer

LLVM_DEBUG(llvm::dbgs() << "which has tdesc: " << makeTensorDescOp << "\n");

auto createPointerType = [](RankedTensorType tensorType) {
return tt::PointerType::get(
tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup);
};

// Create a new block pointer if a suitable one doesn't already exist.
SmallVector<Value> shapes, strides, offsets;
SmallVector<int32_t> sizes;
Expand All @@ -193,16 +213,22 @@ struct TritonIntelTensorDescToBlockPointer
sizes.push_back(static_cast<int32_t>(size));
}

constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
RankedTensorType tensorType;
if constexpr (isLoad)
tensorType = op.getResult().getType();
else
tensorType = op.getSrc().getType();

Value makeTensorPtrOp =
findOrCreateMakeTensorPtr(loc, makeTensorDescOp.getBase(), shapes,
strides, offsets, sizes, builder);
strides, offsets, sizes, tensorType, builder);

LLVM_DEBUG({
llvm::dbgs() << "With:\n";
llvm::dbgs().indent(2) << makeTensorPtrOp << "\n";
});

constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
if constexpr (isLoad) {
auto loadOp = builder.createOrFold<tt::LoadOp>(
loc, makeTensorPtrOp, op.getCache(), op.getEvict(),
Expand Down