Skip to content

Commit 53b2a51

Browse files
[TTIG_PrefetchOp] Add mask argument (#4011)
This PR adds a new optional `mask` argument to the prefetch operation. This prepares for prefetching through tensors of pointers, since loads from a tensor of pointers can be masked. Note: this change comes partially from #3634. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 935ded3 commit 53b2a51

File tree

4 files changed

+42
-4
lines changed

4 files changed

+42
-4
lines changed

test/TritonIntelGPU/tritonintelgpu-invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ tt.func @triton_intel_gpu.extract(%ptr : !tt.ptr<tensor<32x32xf16>>) {
130130

131131
// -----
132132

133+
tt.func @triton_intel_gpu.prefetch(%arg0: !tt.ptr<tensor<2x32xf32>>, %arg1: tensor<4x32xi1>) {
134+
// expected-note@-1 {{prior use here}}
135+
// expected-error@+1 {{use of value '%arg1' expects different type than prior uses: 'tensor<2x32xi1>' vs 'tensor<4x32xi1>'}}
136+
triton_intel_gpu.prefetch %arg0, %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
137+
tt.return
138+
}
139+
140+
// -----
141+
133142
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
134143

135144
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {

test/TritonIntelGPU/tritonintelgpu.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ tt.func @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg
5050

5151
// -----
5252

53+
tt.func @triton_intel_gpu.prefetch(%arg0: !tt.ptr<tensor<2x32xf32>>, %arg1: tensor<2x32xi1>) {
54+
// CHECK-LABEL: @triton_intel_gpu.prefetch
55+
// CHECK: triton_intel_gpu.prefetch %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
56+
triton_intel_gpu.prefetch %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
57+
// CHECK: triton_intel_gpu.prefetch %arg0, %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
58+
triton_intel_gpu.prefetch %arg0, %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
59+
tt.return
60+
}
61+
62+
// -----
63+
5364
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
5465
tt.func @triton_intel_gpu.sub_group_transpose(%local_buffer : !tt.ptr<f16, 3>, %src : tensor<16x16xf16>) -> tensor<16x16xf16> {
5566
// CHECK-LABEL: @triton_intel_gpu.sub_group_transpose

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> {
107107
let hasFolder = 1;
108108
}
109109

110-
def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
110+
def TTIG_PrefetchOp : TTIG_Op<"prefetch", [
111+
TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
112+
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
113+
]> {
111114
let summary = "Tensor prefetch operation";
112115
let description = [{
113116
The `prefetch` operation prefetches an input tensor.
@@ -117,11 +120,20 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
117120
: !tt.ptr<tensor<256x32xf16>>
118121
```
119122
}];
120-
let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, TT_CacheModifierAttr:$cache,
121-
TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile);
123+
let arguments = (
124+
ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
125+
Optional<TT_BoolLike>:$mask,
126+
TT_CacheModifierAttr:$cache,
127+
TT_EvictionPolicyAttr:$evict,
128+
BoolAttr:$isVolatile
129+
);
122130
let results = (outs);
131+
let builders = [
132+
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
133+
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
134+
];
123135
let assemblyFormat = [{
124-
operands attr-dict `:` type($ptr)
136+
$ptr (`,` $mask^)? attr-dict `:` type($ptr)
125137
}];
126138
}
127139

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
197197
return {};
198198
}
199199

200+
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value ptr,
201+
CacheModifier cache, EvictionPolicy evict,
202+
bool isVolatile) {
203+
PrefetchOp::build(builder, state, ptr, /*mask=*/{}, cache, evict, isVolatile);
204+
}
205+
200206
LogicalResult SubGroupTransposeOp::verify() {
201207
RankedTensorType srcType = getSrc().getType();
202208
auto mod = getOperation()->getParentOfType<mlir::ModuleOp>();

0 commit comments

Comments
 (0)