Skip to content

Commit 7076fcb

Browse files
committed
Support tensor of pointer as the pointer parameter of the prefetching operation. Add a mask operand for boundary check.
1 parent 976bcdd commit 7076fcb

File tree

5 files changed

+359
-8
lines changed

5 files changed

+359
-8
lines changed

test/TritonIntelGPU/ops.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: triton-opt %s | FileCheck %s
2+
3+
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
4+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} {
5+
tt.func public @prefetch_op(%tensor_of_ptr: tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>,
6+
%block_ptr: !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>, 1>) {
7+
// COM: Mask of tensor type
8+
// CHECK: triton_intel_gpu.prefetch
9+
%mask_tensor = arith.constant dense<1> : tensor<64x32xi1, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>
10+
triton_intel_gpu.prefetch %tensor_of_ptr, %mask_tensor {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>, tensor<64x32xi1, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>
11+
12+
// COM: Mask of scalar type
13+
// CHECK: triton_intel_gpu.prefetch
14+
%mask_scalar = arith.constant 1 : i1
15+
triton_intel_gpu.prefetch %tensor_of_ptr, %mask_scalar {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>, i1
16+
17+
// COM: Block pointer includes the boundary information. No mask
18+
// CHECK: triton_intel_gpu.prefetch
19+
triton_intel_gpu.prefetch %block_ptr {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>, 1>
20+
tt.return
21+
}
22+
}

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
1+
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s
22

33
// CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
44
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
@@ -72,3 +72,99 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
7272
tt.return
7373
}
7474
}
75+
76+
// -----
77+
78+
// CHECK: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i
79+
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
80+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
81+
// CHECK-LABEL: llvm.func spir_kernelcc @prefetch_tensor_of_pointers
82+
tt.func public @prefetch_tensor_of_pointers(%tensor_of_ptr: tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>) {
83+
// CHECK: %[[MASK:.*]] = llvm.mlir.constant(1 : i8) : i8
84+
// CHECK: %[[VAL_2:.*]] = llvm.mlir.undef : vector<2xi32>
85+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(1 : i32) : i32
86+
// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
87+
// CHECK: %[[BASE_HEIGHT:.*]] = llvm.mlir.constant(8 : i32) : i32
88+
// CHECK: %[[BASE_WIDTH:.*]] = llvm.mlir.constant(64 : i32) : i32
89+
// CHECK: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
90+
91+
// CHECK: %[[ADDR_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)>
92+
// CHECK: %[[ADDR_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)>
93+
// CHECK: %[[ADDR_16:.*]] = llvm.extractvalue {{.*}}[16] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)>
94+
// CHECK: %[[ADDR_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)>
95+
// CHECK: %[[ADDR_48:.*]] = llvm.extractvalue {{.*}}[48] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)>
96+
// CHECK: %[[VAL_13:.*]] = llvm.ptrtoint %[[ADDR_0]] : !llvm.ptr<1> to i64
97+
// CHECK: %[[VAL_14:.*]] = llvm.ptrtoint %[[ADDR_1]] : !llvm.ptr<1> to i64
98+
// CHECK: %[[PITCH:.*]] = llvm.sub %[[VAL_14]], %[[VAL_13]] : i64
99+
// CHECK: %[[UNIFIED_PITCH:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[PITCH]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
100+
// CHECK: %[[UNIFIED_PITCH_I32:.*]] = llvm.trunc %[[UNIFIED_PITCH]] : i64 to i32
101+
// CHECK: %[[VAL_18:.*]] = llvm.intr.umax(%[[UNIFIED_PITCH_I32]], %[[BASE_WIDTH]]) : (i32, i32) -> i32
102+
// CHECK: %[[PITCH_IN_BYTES_I32:.*]] = llvm.trunc %[[VAL_18]] : i32 to i32
103+
104+
// CHECK: %[[UNIFIED_MASK:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8
105+
// CHECK: %[[UNIFIED_MASK_I1:.*]] = llvm.trunc %[[UNIFIED_MASK]] : i8 to i1
106+
// CHECK: %[[OFFSET_Y:.*]] = llvm.select %[[UNIFIED_MASK_I1]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32
107+
// CHECK: %[[UNIFIED_BASE:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_13]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
108+
// CHECK: %[[VAL_26:.*]] = llvm.inttoptr %[[UNIFIED_BASE]] : i64 to !llvm.ptr<1>
109+
// CHECK: %[[VAL_27:.*]] = llvm.insertelement %[[CST_0]], {{.*}} : vector<2xi32>
110+
// CHECK: %[[OFFSETS:.*]] = llvm.insertelement %[[OFFSET_Y]], {{.*}} : vector<2xi32>
111+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_26]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[OFFSETS]])
112+
113+
// CHECK: %[[VAL_29:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8
114+
// CHECK: %[[VAL_30:.*]] = llvm.trunc %[[VAL_29]] : i8 to i1
115+
// CHECK: %[[VAL_31:.*]] = llvm.select %[[VAL_30]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32
116+
// CHECK: %[[VAL_32:.*]] = llvm.ptrtoint %[[ADDR_16]] : !llvm.ptr<1> to i64
117+
// CHECK: %[[VAL_33:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_32]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
118+
// CHECK: %[[VAL_34:.*]] = llvm.inttoptr %[[VAL_33]] : i64 to !llvm.ptr<1>
119+
// CHECK: %[[VAL_35:.*]] = llvm.insertelement %[[VAL_31]], {{.*}} : vector<2xi32>
120+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_34]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[VAL_35]])
121+
122+
// CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8
123+
// CHECK: %[[VAL_37:.*]] = llvm.trunc %[[VAL_36]] : i8 to i1
124+
// CHECK: %[[VAL_38:.*]] = llvm.select %[[VAL_37]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32
125+
// CHECK: %[[VAL_39:.*]] = llvm.ptrtoint %[[ADDR_32]] : !llvm.ptr<1> to i64
126+
// CHECK: %[[VAL_40:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_39]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
127+
// CHECK: %[[VAL_41:.*]] = llvm.inttoptr %[[VAL_40]] : i64 to !llvm.ptr<1>
128+
// CHECK: %[[VAL_42:.*]] = llvm.insertelement %[[VAL_38]], {{.*}} : i32] : vector<2xi32>
129+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_41]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[VAL_42]])
130+
131+
// CHECK: %[[VAL_43:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8
132+
// CHECK: %[[VAL_44:.*]] = llvm.trunc %[[VAL_43]] : i8 to i1
133+
// CHECK: %[[VAL_45:.*]] = llvm.select %[[VAL_44]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32
134+
// CHECK: %[[VAL_46:.*]] = llvm.ptrtoint %[[ADDR_48]] : !llvm.ptr<1> to i64
135+
// CHECK: %[[VAL_47:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_46]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
136+
// CHECK: %[[VAL_48:.*]] = llvm.inttoptr %[[VAL_47]] : i64 to !llvm.ptr<1>
137+
// CHECK: %[[VAL_49:.*]] = llvm.insertelement %[[VAL_45]], {{.*}} : i32] : vector<2xi32>
138+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_48]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[VAL_49]])
139+
140+
%mask_tensor = arith.constant dense<1> : tensor<64x32xi1, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>
141+
triton_intel_gpu.prefetch %tensor_of_ptr, %mask_tensor {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>, tensor<64x32xi1, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>
142+
143+
// CHECK: %[[VAL_52:.*]] = llvm.select %[[TRUE]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32
144+
// CHECK: %[[VAL_53:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_13]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
145+
// CHECK: %[[VAL_54:.*]] = llvm.inttoptr %[[VAL_53]] : i64 to !llvm.ptr<1>
146+
// CHECK: %[[VAL_55:.*]] = llvm.insertelement %[[VAL_52]], {{.*}} : vector<2xi32>
147+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_54]], {{.*}}, {{.*}}, {{.*}}, %[[VAL_55]])
148+
149+
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_32]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
150+
// CHECK: %[[VAL_57:.*]] = llvm.inttoptr %[[VAL_56]] : i64 to !llvm.ptr<1>
151+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_57]], {{.*}}, {{.*}}, {{.*}}, %[[VAL_55]])
152+
153+
// CHECK: %[[VAL_58:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_39]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
154+
// CHECK: %[[VAL_59:.*]] = llvm.inttoptr %[[VAL_58]] : i64 to !llvm.ptr<1>
155+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_59]], {{.*}}, {{.*}}, {{.*}}, %[[VAL_55]])
156+
157+
// CHECK: %[[VAL_60:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_46]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64
158+
// CHECK: %[[VAL_61:.*]] = llvm.inttoptr %[[VAL_60]] : i64 to !llvm.ptr<1>
159+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_61]], {{.*}}, {{.*}}, {{.*}}, %[[VAL_55]])
160+
161+
%mask_scalar = arith.constant 1 : i1
162+
triton_intel_gpu.prefetch %tensor_of_ptr, %mask_scalar {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>, i1
163+
164+
// CHECK-COUNT-4: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i
165+
166+
triton_intel_gpu.prefetch %tensor_of_ptr {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>
167+
168+
tt.return
169+
}
170+
}

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> {
107107
let hasFolder = 1;
108108
}
109109

110-
def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
110+
def TTIG_PrefetchOp : TTIG_Op<"prefetch", [
111+
TypesMatchWith<"mask type matches ptr type or to be bool", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
112+
"($_op.getOperands().size() <= 1) || ($mask.getType().isSignlessInteger(1)) || std::equal_to<>()">,
113+
]> {
111114
let summary = "Tensor prefetch operation";
112115
let description = [{
113116
The `prefetch` operation prefetches an input tensor.
@@ -117,11 +120,20 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
117120
: !tt.ptr<tensor<256x32xf16>
118121
```
119122
}];
120-
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, TT_CacheModifierAttr:$cache,
121-
TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile);
123+
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
124+
Optional<TT_BoolLike>:$mask,
125+
TT_CacheModifierAttr:$cache,
126+
TT_EvictionPolicyAttr:$evict,
127+
BoolAttr:$isVolatile);
128+
129+
let builders = [
130+
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
131+
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
132+
];
133+
122134
let results = (outs);
123135
let assemblyFormat = [{
124-
operands attr-dict `:` type($ptr)
136+
$ptr (`,` $mask^)? attr-dict `:` type($ptr) (`,` type($mask)^)?
125137
}];
126138
}
127139

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ static unsigned getDimSize(Type type, unsigned dim) {
8686

8787
namespace mlir::triton::gpu::intel {
8888

89+
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value ptr,
90+
CacheModifier cache, EvictionPolicy evict,
91+
bool isVolatile) {
92+
PrefetchOp::build(builder, state, ptr, /*mask=*/{}, cache, evict, isVolatile);
93+
}
94+
8995
LogicalResult GlueOp::verify() {
9096
SmallVector<Type> inputTypes;
9197
for (Value input : getOperands())

0 commit comments

Comments
 (0)