|
| 1 | +#include "Halide.h" |
| 2 | +#include <fstream> |
| 3 | +#include <regex> |
| 4 | + |
| 5 | +using namespace Halide; |
| 6 | + |
| 7 | +bool compile_and_check_vscale(Func &f, |
| 8 | + const std::string &name, |
| 9 | + const Target &t, |
| 10 | + int exp_vscale, |
| 11 | + const std::string &exp_intrin) { |
| 12 | + |
| 13 | + // Look into llvm-ir and check function attributes for vscale_range |
| 14 | + auto llvm_file_name = name + ".ll"; |
| 15 | + f.compile_to_llvm_assembly(llvm_file_name, f.infer_arguments(), t); |
| 16 | + |
| 17 | + Internal::assert_file_exists(llvm_file_name); |
| 18 | + std::ifstream llvm_file; |
| 19 | + llvm_file.open(llvm_file_name); |
| 20 | + std::string line; |
| 21 | + // Pattern to extract "n" and "m" in "vscale_range(n,m)" |
| 22 | + std::regex vscale_regex(R"(vscale_range\(\s*([0-9]+)\s*,\s*([0-9]+)\s*\))"); |
| 23 | + |
| 24 | + int act_vscale = 0; |
| 25 | + bool intrin_found = false; |
| 26 | + |
| 27 | + while (getline(llvm_file, line)) { |
| 28 | + // Check vscale_range |
| 29 | + std::smatch match; |
| 30 | + if (std::regex_search(line, match, vscale_regex) && match[1] == match[2]) { |
| 31 | + act_vscale = std::stoi(match[1]); |
| 32 | + } |
| 33 | + // Check intrin |
| 34 | + if (line.find(exp_intrin) != std::string::npos) { |
| 35 | + intrin_found = true; |
| 36 | + } |
| 37 | + } |
| 38 | + |
| 39 | + if (act_vscale != exp_vscale) { |
| 40 | + printf("[%s] Found vscale_range %d, while expected %d\n", name.c_str(), act_vscale, exp_vscale); |
| 41 | + return false; |
| 42 | + } |
| 43 | + if (!intrin_found) { |
| 44 | + printf("[%s] Cannot find expected intrin %s\n", name.c_str(), exp_intrin.c_str()); |
| 45 | + return false; |
| 46 | + } |
| 47 | + return true; |
| 48 | +} |
| 49 | + |
| 50 | +Var x("x"), y("y"); |
| 51 | + |
| 52 | +bool test_vscale(int vectorization_factor, int vector_bits, int exp_vscale) { |
| 53 | + Func f("f"); |
| 54 | + f(x, y) = absd(x, y); |
| 55 | + f.compute_root().vectorize(x, vectorization_factor); |
| 56 | + |
| 57 | + Target t("arm-64-linux-sve2-no_asserts-no_runtime-no_bounds_query"); |
| 58 | + t.vector_bits = vector_bits; |
| 59 | + |
| 60 | + std::stringstream name; |
| 61 | + name << "test_vscale_v" << vectorization_factor << "_vector_bits_" << vector_bits; |
| 62 | + |
| 63 | + // sve or neon |
| 64 | + std::string intrin = exp_vscale > 0 ? "llvm.aarch64.sve.sabd" : "llvm.aarch64.neon.sabd"; |
| 65 | + |
| 66 | + return compile_and_check_vscale(f, name.str(), t, exp_vscale, intrin); |
| 67 | +} |
| 68 | + |
| 69 | +int main(int argc, char **argv) { |
| 70 | + |
| 71 | + bool ok = true; |
| 72 | + |
| 73 | + ok &= test_vscale(4, 128, 1); // Regular case: <vscale x 4 x ty> with vscale=1 |
| 74 | + ok &= test_vscale(3, 128, 0); // Fallback due to odd vectorization factor |
| 75 | + ok &= test_vscale(8, 512, 4); // Regular case: <vscale x 2 x ty> with vscale=4 |
| 76 | + ok &= test_vscale(4, 512, 0); // Fallback due to <vscale x 1 x ty> |
| 77 | + |
| 78 | + if (!ok) { |
| 79 | + return 1; |
| 80 | + } |
| 81 | + printf("Success!\n"); |
| 82 | + return 0; |
| 83 | +} |
0 commit comments