Skip to content

Commit 95cf920

Browse files
Add tests for fallback SVE depending on vector_bits
1 parent 64aa9e1 commit 95cf920

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ tests(GROUPS correctness
110110
extern_stage_on_device.cpp
111111
extract_concat_bits.cpp
112112
failed_unroll.cpp
113+
fallback_vscale_sve.cpp
113114
fast_trigonometric.cpp
114115
fibonacci.cpp
115116
fit_function.cpp
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#include "Halide.h"
2+
#include <fstream>
3+
#include <regex>
4+
5+
using namespace Halide;
6+
7+
// Compiles `f` for target `t` to LLVM assembly in "<name>.ll", then scans the
// emitted text line by line for two things:
//   * a `vscale_range(n,m)` attribute whose two bounds are equal; the value of
//     the last such attribute seen is compared against `exp_vscale` (when no
//     such attribute is seen, the value compared is 0);
//   * any occurrence of the substring `exp_intrin`.
// Returns true only if both checks pass; otherwise prints a diagnostic
// prefixed with `name` and returns false.
bool compile_and_check_vscale(Func &f,
                              const std::string &name,
                              const Target &t,
                              int exp_vscale,
                              const std::string &exp_intrin) {
    const std::string ir_path = name + ".ll";
    f.compile_to_llvm_assembly(ir_path, f.infer_arguments(), t);
    Internal::assert_file_exists(ir_path);

    // Captures "n" and "m" from "vscale_range(n,m)".
    const std::regex range_pattern(R"(vscale_range\(\s*([0-9]+)\s*,\s*([0-9]+)\s*\))");

    int found_vscale = 0;
    bool saw_intrin = false;

    std::ifstream ir_stream(ir_path);
    for (std::string text; std::getline(ir_stream, text);) {
        // Only a range with identical lower and upper bound is recorded.
        std::smatch groups;
        const bool has_range = std::regex_search(text, groups, range_pattern);
        if (has_range && groups[1] == groups[2]) {
            found_vscale = std::stoi(groups.str(1));
        }

        // A single hit anywhere in the file is enough for the intrinsic.
        if (!saw_intrin && text.find(exp_intrin) != std::string::npos) {
            saw_intrin = true;
        }
    }

    if (found_vscale != exp_vscale) {
        printf("[%s] Found vscale_range %d, while expected %d\n", name.c_str(), found_vscale, exp_vscale);
        return false;
    }
    if (!saw_intrin) {
        printf("[%s] Cannot find expected intrin %s\n", name.c_str(), exp_intrin.c_str());
        return false;
    }
    return true;
}
49+
50+
// Halide Vars shared at file scope; test_vscale uses them as the two dimensions of its Func.
Var x("x"), y("y");
51+
52+
bool test_vscale(int vectorization_factor, int vector_bits, int exp_vscale) {
53+
Func f("f");
54+
f(x, y) = absd(x, y);
55+
f.compute_root().vectorize(x, vectorization_factor);
56+
57+
Target t("arm-64-linux-sve2-no_asserts-no_runtime-no_bounds_query");
58+
t.vector_bits = vector_bits;
59+
60+
std::stringstream name;
61+
name << "test_vscale_v" << vectorization_factor << "_vector_bits_" << vector_bits;
62+
63+
// sve or neon
64+
std::string intrin = exp_vscale > 0 ? "llvm.aarch64.sve.sabd" : "llvm.aarch64.neon.sabd";
65+
66+
return compile_and_check_vscale(f, name.str(), t, exp_vscale, intrin);
67+
}
68+
69+
// Runs every vscale scenario (all of them, even after a failure) and
// reports overall success via the exit code: 0 with "Success!" printed when
// every scenario passed, 1 otherwise.
int main(int argc, char **argv) {
    struct Scenario {
        int vectorization_factor;
        int vector_bits;
        int expected_vscale;
    };

    const Scenario scenarios[] = {
        {4, 128, 1},  // Regular case: <vscale x 4 x ty> with vscale=1
        {3, 128, 0},  // Fallback due to odd vectorization factor
        {8, 512, 4},  // Regular case: <vscale x 2 x ty> with vscale=4
        {4, 512, 0},  // Fallback due to <vscale x 1 x ty>
    };

    bool all_passed = true;
    for (const Scenario &s : scenarios) {
        // No short-circuit: each scenario is executed regardless of earlier results.
        const bool passed = test_vscale(s.vectorization_factor, s.vector_bits, s.expected_vscale);
        all_passed = all_passed && passed;
    }

    if (all_passed) {
        printf("Success!\n");
        return 0;
    }
    return 1;
}

test/warning/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ tests(GROUPS warning
44
require_const_false.cpp
55
sliding_vectors.cpp
66
unscheduled_update_def.cpp
7+
unsupported_vectorization_sve.cpp
78
emulated_float16.cpp
89
)
910

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include "Halide.h"
2+
#include "halide_test_dirs.h"
3+
4+
using namespace Halide;
5+
6+
// Compiles a float pipeline vectorized by a factor of vscale * 3 for an
// arm-64 sve2 target with vector_bits = 128 * vscale, writing the LLVM
// assembly to "unused.ll" under the test tmp dir. Always returns 0.
int main(int argc, char **argv) {
    constexpr int vscale = 2;
    constexpr int vector_bits = 128 * vscale;

    Func f;
    Var x;
    f(x) = x * 0.1f;
    f.vectorize(x, vscale * 3);

    const std::string target_string = "arm-64-linux-sve2-vector_bits_" + std::to_string(vector_bits);
    Target t(target_string);

    // SVE is disabled with user_warning,
    // which would have ended up with emitting <vscale x 3 x float> if we didn't.
    const std::string output_path = Internal::get_test_tmp_dir() + "unused.ll";
    f.compile_to_llvm_assembly(output_path, f.infer_arguments(), "f", t);

    return 0;
}

0 commit comments

Comments
 (0)