Skip to content

Commit 28e9720

Browse files
committed
Add unit test for lowering the tt.load of tensor of pointers with structured memory.
Fix segfault in LoadOpToBlockIOConversion Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9856962 commit 28e9720

File tree

3 files changed

+192
-9
lines changed

3 files changed

+192
-9
lines changed

python/test/unit/intel/test_block_load.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import itertools
2+
3+
import numpy as np
14
import pytest
25
import torch
36
import pathlib
@@ -7,6 +10,27 @@
710

811

912
@pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32]])
13+
class DpasLayout:
14+
15+
def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
16+
rep_cluster):
17+
self.repeatCount = repeatCount
18+
self.systolic_depth = systolic_depth
19+
self.execution_size = execution_size
20+
self.ops_per_chan = ops_per_chan
21+
self.threads_per_warp = threads_per_warp
22+
self.warps_per_cta = warps_per_cta
23+
self.rep_cluster = rep_cluster
24+
25+
def __str__(self):
26+
return f"#triton_intel_gpu.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"
27+
28+
29+
def warps_per_cta(layout):
30+
return layout.warps_per_cta
31+
32+
33+
@pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [64, 64], [64, 32], [32, 32]])
1034
@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
1135
@pytest.mark.parametrize("transpose", [True, False])
1236
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
@@ -79,3 +103,107 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
79103
kernel[(1, 1, 1)](a, x, b, y)
80104
#import pdb; pdb.set_trace()
81105
assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)
106+
107+
108+
layouts = [
109+
# Layout for Xe2 and Xe2+
110+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
111+
warps_per_cta=[1, 4], rep_cluster=[1, 2]),
112+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
113+
warps_per_cta=[8, 4], rep_cluster=[4, 2]),
114+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16,
115+
warps_per_cta=[8, 4], rep_cluster=[1, 1]),
116+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=32,
117+
warps_per_cta=[1, 4], rep_cluster=[1, 2]),
118+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
119+
warps_per_cta=[8, 4], rep_cluster=[4, 2]),
120+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32,
121+
warps_per_cta=[8, 4], rep_cluster=[1, 1]),
122+
# Layout for Xe
123+
]
124+
125+
126+
@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])])
127+
@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
128+
@pytest.mark.parametrize("layout", layouts)
129+
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
130+
def test_tensor_pointer_block_load(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
131+
132+
warps = warps_per_cta(layout)
133+
num_warps = int(np.prod(warps))
134+
threads_per_warp = layout.threads_per_warp
135+
ops_per_chan = layout.ops_per_chan
136+
A_width = 1 if ops_per_chan == 1 else ops_per_chan // 2
137+
B_width = ops_per_chan
138+
139+
ty = {"float32": "f32", "float16": "f16", "int8": "i8"}[dtype_str]
140+
141+
support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
142+
143+
ir = f"""
144+
#mma = {layout}
145+
#dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}>
146+
#dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}>
147+
module attributes {{triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, {"triton_intel_gpu.support_sg_2d_block," if support_block_io else ""} triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
148+
tt.func public @tensor_pointer_block_load(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg6: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg7: i32 {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
149+
// A matrix
150+
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>>
151+
%2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a>
152+
%3 = tt.splat %arg6 : i32 -> tensor<{M}x1xi32, #dot_a>
153+
%4 = arith.muli %2, %3 : tensor<{M}x1xi32, #dot_a>
154+
%5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
155+
%6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
156+
%7 = tt.broadcast %4 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
157+
%8 = tt.broadcast %6 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
158+
%9 = arith.addi %7, %8 : tensor<{M}x{N}xi32, #dot_a>
159+
160+
%10 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
161+
%11 = tt.addptr %10, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
162+
%12 = tt.load %11 {{triton_intel_gpu.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
163+
%13 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
164+
%14 = tt.addptr %13, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
165+
tt.store %14, %12 {{boundaryCheck = array<i32: 0, 1>}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
166+
167+
// B matrix
168+
%22 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
169+
%44 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
170+
%46 = tt.expand_dims %44 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
171+
%48 = tt.splat %arg7 : i32 -> tensor<{M}x1xi32, #dot_b>
172+
%49 = arith.muli %46, %48 : tensor<{M}x1xi32, #dot_b>
173+
%50 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
174+
%51 = tt.broadcast %49 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
175+
%52 = tt.broadcast %50 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
176+
%53 = arith.addi %51, %52 : tensor<{M}x{N}xi32, #dot_b>
177+
178+
%54 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
179+
%55 = tt.addptr %54, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
180+
%56 = tt.load %55 {{triton_intel_gpu.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
181+
%57 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
182+
%58 = tt.addptr %57, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
183+
tt.store %58, %56 {{boundaryCheck = array<i32: 0, 1>}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
184+
185+
tt.return
186+
}}
187+
}}
188+
"""
189+
190+
torch_dtype = getattr(torch, dtype_str)
191+
if torch_dtype.is_floating_point:
192+
a = torch.randn((M, N), dtype=torch_dtype, device=device)
193+
else:
194+
a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, device=device)
195+
196+
x = torch.empty_like(a)
197+
y = torch.empty_like(a)
198+
199+
temp_file = tmp_path / "test_tensor_pointer_block_load.ttgir"
200+
temp_file.write_text(ir)
201+
kernel = triton.compile(str(temp_file))
202+
203+
if support_block_io:
204+
# assert '2d block io' in kernel.asm['llir']
205+
pass
206+
207+
kernel[(1, 1, 1)](a, x, a.stride(0), a, y, a.stride(0))
208+
209+
assert torch.equal(a, x) and torch.equal(a, y)

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,39 @@ loadCacheControlToCacheControls(Builder &builder,
110110
return builder.getAttr<TritonGEN::DecorationCacheControlAttr>(decorations);
111111
}
112112

113+
[[maybe_unused]] static bool
114+
isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
115+
VectorType resTy = op.getRes().getType();
116+
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
117+
bool needsResElemSizeEqualTo32 =
118+
op.getElemSizeInBits() == 32 || op.getVnniTransform();
119+
assert((!needsResElemSizeEqualTo32 || resElemTySize == 32) &&
120+
"Expecting 32-bit element type");
121+
if (!needsResElemSizeEqualTo32 && resElemTySize != 16)
122+
return false;
123+
124+
if (op.getVnniTransform())
125+
return true;
126+
127+
if (op.getTranspose() && op.getTileHeight() != 16)
128+
return false;
129+
130+
uint32_t tileWidth = op.getTileWidth();
131+
uint32_t tileHeight = op.getTileHeight();
132+
switch (op.getElemSizeInBits()) {
133+
case 8:
134+
return (tileWidth == 32);
135+
case 16:
136+
return (tileWidth == 16) && (tileHeight != 32);
137+
case 32:
138+
return (tileWidth == 8 || tileWidth == 16) && (tileHeight != 32);
139+
default:
140+
llvm_unreachable("unexpected element size");
141+
}
142+
143+
return false;
144+
}
145+
113146
[[maybe_unused]] static Value
114147
createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
115148
ConversionPatternRewriter &rewriter) {
@@ -119,12 +152,20 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
119152
auto b = TritonLLVMOpBuilder(loc, rewriter);
120153

121154
Value ptr = op.getPtr();
122-
Value baseWidth = op.getBaseWidth();
123155
Value baseHeight = op.getBaseHeight();
124156
Value basePitch = op.getBasePitch();
125-
Value x = op.getX();
126157
Value y = op.getY();
127158

159+
// compensate the non-64 byte aligned base.
160+
Value offset =
161+
b.trunc(i32_ty, b.and_(b.ptrtoint(i64_ty, ptr), b.i64_val(0x3f)));
162+
// In number of bytes.
163+
Value baseWidth = b.add(op.getBaseWidth(), offset);
164+
// In number of scalar elements.
165+
Value offsetX =
166+
b.add(op.getX(),
167+
b.lshr(offset, b.i32_val(std::log2(op.getElemSizeInBits() / 8))));
168+
128169
std::string funcName =
129170
"llvm.genx.GenISA.LSC2DBlockRead." + getGenISATypeMangling(resType);
130171
IntegerType int1Ty = rewriter.getIntegerType(1);
@@ -139,7 +180,7 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
139180
baseWidth.getType(),
140181
baseHeight.getType(),
141182
basePitch.getType(),
142-
x.getType(),
183+
offsetX.getType(),
143184
y.getType(),
144185
int32Ty,
145186
int32Ty,
@@ -153,7 +194,7 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
153194
b.sub(baseWidth, one),
154195
b.sub(baseHeight, one),
155196
b.sub(basePitch, one),
156-
x,
197+
offsetX,
157198
y,
158199
b.i32_val(op.getElemSizeInBits()),
159200
b.i32_val(op.getTileWidth()),
@@ -421,8 +462,9 @@ struct TritonMatrix2DBlockLoadLowering
421462
LogicalResult
422463
matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
423464
ConversionPatternRewriter &rewriter) const override {
424-
if (op.getElemSizeInBits() == 8 && op.getTileWidth() == 16 &&
425-
op.getVBlocks() != 4 && !op.getVnniTransform()) {
465+
if (!isOCLBuiltinAvailable(op) ||
466+
op.getElemSizeInBits() == 8 && op.getTileWidth() == 16 &&
467+
op.getVBlocks() != 4 && !op.getVnniTransform()) {
426468
// TODO: add ocl builtin/spirv intrinsics for 8b 16 column 1 vBlock & 2
427469
// vBlock reads
428470
rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,11 @@ struct LoadOpToBlockIOConversion
494494
LogicalResult
495495
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
496496
ConversionPatternRewriter &rewriter) const final {
497+
ModuleOp mod = op->getParentOfType<ModuleOp>();
498+
if (!mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
499+
getSupportSG2DBlockAttrName()))
500+
return failure();
501+
497502
Attribute blockIOAttr =
498503
op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
499504
if (!blockIOAttr)
@@ -727,6 +732,10 @@ struct LoadOpToBlockIOConversion
727732
if (otherElems.size())
728733
others[offset] = otherElems[i];
729734
}
735+
// ptrs[{0, 0}] and ptrs[{1, 0}] are currently used to calculate the
736+
// pitch.
737+
if (ptrs.count({0, 0}) < 1 || ptrs.count({1, 0}) < 1)
738+
return failure();
730739
}
731740

732741
unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN;
@@ -769,6 +778,8 @@ struct LoadOpToBlockIOConversion
769778
// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
770779
// by enlarging the vBlocks.
771780
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
781+
if (totalBytesPerRowPerDPASOp > 64)
782+
return failure();
772783
numOperandsPer2DloadN =
773784
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
774785

@@ -815,6 +826,8 @@ struct LoadOpToBlockIOConversion
815826
StringAttr kWarp = str_attr("warp");
816827
StringAttr kBlock = str_attr("block");
817828

829+
const unsigned originalElemBits = elemSizeInBits;
830+
818831
ValueTable loadVals;
819832
for (int inner = 0; inner < numRepInner;
820833
inner += numOperandsInnerDimPerLoad) {
@@ -884,9 +897,9 @@ struct LoadOpToBlockIOConversion
884897
/*tile_height*/ tileHeight,
885898
/*v_blocks*/ vBlocks,
886899
/*transpose*/ false,
887-
/*vnni_transform*/ opIdx ==
888-
DpasEncodingAttr::OpIdx::OperandB &&
889-
usePackedType);
900+
/*vnni_transform*/
901+
(usePackedType && !isOperandA && !isTransposeRequired &&
902+
originalElemBits != 32));
890903
return SmallVector<Value, 1>{load2dOp};
891904
});
892905
Value ret = *endBlock.args_begin();

0 commit comments

Comments
 (0)