Skip to content

Commit 962c421

Browse files
[MLIR][AArch64] Change some tests to ensure SVE vector length is the same throughout the function (#147506)
This change only applies to functions that can be reasonably expected to use SVE registers. Modifying vector length in the middle of a function might cause incorrect stack deallocation if there are callee-saved SVE registers or incorrect access to SVE stack slots. Addresses (non-issue) #143670
1 parent d45d20e commit 962c421

File tree

7 files changed

+35
-22
lines changed

7 files changed

+35
-22
lines changed

mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ extern "C" {
3838
// PR_SVE_VL_LEN_MASK.
3939
#define PR_VL_LEN_MASK 0xffff
4040

41+
/// Sets the vector length (streaming or not, as indicated by `option`) to
42+
/// `bits`.
43+
///
44+
/// Caveat emptor: If a function has allocated stack slots for SVE registers
45+
/// (e.g. slots for callee-saved SVE registers or spill slots) changing
46+
/// the vector length is tricky and error prone - it may cause incorrect stack
47+
/// deallocation or incorrect access to stack slots.
48+
///
49+
/// The recommended strategy is to call `setArmVectorLength` only from functions
50+
/// that do not access SVE registers, either by themselves or by inlining other
51+
/// functions.
4152
static void setArmVectorLength(std::string_view helper_name, int option,
4253
uint32_t bits) {
4354
#if defined(__linux__) && defined(__aarch64__)

mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-scalable-inner-tile.mlir

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,21 @@ func.func @main() {
3939
[ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91, 98, 105, 112]
4040
]> : tensor<7x16xi32>
4141

42+
43+
// Set vscale to 2 (vector width = 256). This will have identical effect to:
44+
// * qemu-aarch64 -cpu max,sve-max-vq=2 (...)
45+
%c256 = arith.constant 256 : i32
46+
func.call @setArmVLBits(%c256) : (i32) -> ()
47+
4248
func.call @pack(%A) : (tensor<7x16xi32>) -> ()
4349

4450
return
4551
}
4652

47-
func.func private @pack(%A: tensor<7x16xi32>) {
53+
func.func private @pack(%A: tensor<7x16xi32>) attributes {no_inline} {
4854
%c1 = arith.constant 1 : index
4955
%pad_val = arith.constant 123 : i32
5056

51-
// Set vscale to 2 (vector width = 256). This will have identical effect to:
52-
// * qemu-aarch64 -cpu max,sve-max-vq=2 (...)
53-
%c256 = arith.constant 256 : i32
54-
func.call @setArmVLBits(%c256) : (i32) -> ()
55-
5657
// Scalable tile size
5758
%vs = vector.vscale
5859
%c8 = arith.constant 8 : index

mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-scalable-inner-tile.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
1111

1212
/// End-to-end test for linalg.pack + linalg.unpack where one of the inner tile sizes is
13-
/// scalable.
13+
/// scalable.
1414
/// NOTE: Vectorization has not been enabled yet!
1515

1616

@@ -21,7 +21,12 @@ func.func @main() {
2121
// (If your platform supports it, you can play with other values as well)
2222
%c256 = arith.constant 256 : i32
2323
func.call @setArmVLBits(%c256) : (i32) -> ()
24+
func.call @test_pack_unpack_scalable_inner_tile() : () -> ()
2425

26+
return
27+
}
28+
29+
func.func @test_pack_unpack_scalable_inner_tile() attributes {no_inline} {
2530
// Dynamic/scalable tile size (vscale x 4)
2631
%c4 = arith.constant 4 : index
2732
%vs = vector.vscale

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-deinterleave.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func.func @entry() {
1717
return
1818
}
1919

20-
func.func @test_deinterleave() {
20+
func.func @test_deinterleave() attributes {no_inline} {
2121
%step_vector = llvm.intr.stepvector : vector<[4]xi8>
2222
vector.print %step_vector : vector<[4]xi8>
2323
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-setArmVLBits.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
// RUN: %{compile} | %{run} | FileCheck %s
1010

11-
func.func @checkVScale() {
11+
func.func @checkVScale() attributes {no_inline} {
1212
%vscale = vector.vscale
1313
vector.print str "vscale = "
1414
vector.print %vscale : index

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-non-trailing.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
// Test the transfer_read with vector type with a non-trailing scalable
1616
// dimension as transformed by the pattern LegalizeTransferRead.
1717

18-
func.func @transfer_read_scalable_non_trailing(%vs : i32, %M : memref<?x8xi8>) {
19-
func.call @setArmVLBits(%vs) : (i32) -> ()
20-
18+
func.func @transfer_read_scalable_non_trailing(%M : memref<?x8xi8>) attributes {no_inline} {
2119
// Read an LLVM-illegal vector
2220
%c0 = arith.constant 0 : index
2321
%c0_i8 = arith.constant 0 : i8
@@ -56,14 +54,16 @@ func.func @main() {
5654
// CHECK:( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
5755
vector.print str "Result(VL128):\n"
5856
%c128 = arith.constant 128 : i32
59-
func.call @transfer_read_scalable_non_trailing(%c128, %MM) : (i32, memref<?x8xi8>) -> ()
57+
func.call @setArmVLBits(%c128) : (i32) -> ()
58+
func.call @transfer_read_scalable_non_trailing(%MM) : (memref<?x8xi8>) -> ()
6059

6160
// CHECK-LABEL: Result(VL256):
6261
// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
6362
// CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
6463
vector.print str "Result(VL256):\n"
6564
%c256 = arith.constant 256 : i32
66-
func.call @transfer_read_scalable_non_trailing(%c256, %MM) : (i32, memref<?x8xi8>) -> ()
65+
func.call @setArmVLBits(%c256) : (i32) -> ()
66+
func.call @transfer_read_scalable_non_trailing(%MM) : (memref<?x8xi8>) -> ()
6767

6868
return
6969
}

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func.func private @prepareRHSTestData(%in: vector<4x8xi8>) -> memref<?xi8> {
116116

117117
// CHECK-IR-LABEL: llvm.func @test_smmla
118118
// CHECK-IR-COUNT-4: arm_sve.intr.smmla
119-
func.func @test_smmla() {
119+
func.func @test_smmla() attributes {no_inline} {
120120

121121
%c0 = arith.constant 0 : index
122122
%c0_i32 = arith.constant 0 : i32
@@ -131,10 +131,6 @@ func.func @test_smmla() {
131131
%acc_mem = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xi32>) -> memref<4x?xi32>
132132
%acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : memref<4x?xi32>, vector<4x[4]xi32>
133133

134-
// FIXME: Workaround for a crash, see https://github.com/llvm/llvm-project/issues/143670
135-
%acc_cast = memref.cast %acc_mem : memref<4x?xi32> to memref<*xi32>
136-
call @printMemrefI32(%acc_cast) : (memref<*xi32>) -> ()
137-
138134
// LHS test data
139135
%lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
140136
[-20, 17, -32, -47, 37, 22, -7, -21],
@@ -186,7 +182,7 @@ func.func @test_smmla() {
186182

187183
// CHECK-IR-LABEL: llvm.func @test_ummla
188184
// CHECK-IR-COUNT-4: arm_sve.intr.ummla
189-
func.func @test_ummla() {
185+
func.func @test_ummla() attributes {no_inline} {
190186

191187
%c0 = arith.constant 0 : index
192188
%c0_i32 = arith.constant 0 : i32
@@ -253,7 +249,7 @@ func.func @test_ummla() {
253249

254250
// CHECK-IR-LABEL: llvm.func @test_usmmla
255251
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
256-
func.func @test_usmmla() {
252+
func.func @test_usmmla() attributes {no_inline} {
257253

258254
%c0 = arith.constant 0 : index
259255
%c0_i32 = arith.constant 0 : i32
@@ -321,7 +317,7 @@ func.func @test_usmmla() {
321317

322318
// CHECK-IR-LABEL: llvm.func @test_summla
323319
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
324-
func.func @test_summla() {
320+
func.func @test_summla() attributes {no_inline} {
325321

326322
%c0 = arith.constant 0 : index
327323
%c0_i32 = arith.constant 0 : i32

0 commit comments

Comments
 (0)