Skip to content

Commit 23f4988

Browse files
Support Call::dynamic_shuffle for LUT in SVE2
The OptimizeShuffles pass is enabled in CodeGen_ARM for SVE2. - Detects a gather load whose index range is bounded within a certain value, e.g. a look-up table (LUT). - Transforms it into a contiguous load + Call::dynamic_shuffle, which is lowered to a TBL instruction by codegen. This is especially useful for vectorizing with long vectors in SME2 streaming mode, where the general form of gather load is unsupported. OptimizeShuffles is modified so that it can be used commonly between targets (for now, Hexagon and ARM SVE2).
1 parent 9d5081c commit 23f4988

File tree

6 files changed

+169
-52
lines changed

6 files changed

+169
-52
lines changed

src/CodeGen_ARM.cpp

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "IROperator.h"
1515
#include "IRPrinter.h"
1616
#include "LLVM_Headers.h"
17+
#include "OptimizeShuffles.h"
1718
#include "Simplify.h"
1819
#include "Substitute.h"
1920
#include "Util.h"
@@ -227,6 +228,7 @@ class CodeGen_ARM : public CodeGen_Posix {
227228
Value *interleave_vectors(const std::vector<Value *> &) override;
228229
Value *shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) override;
229230
Value *shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices);
231+
Value *shuffle_scalable_vectors_general_llvm(Value *a, Value *b, Value *indices, int min_index, int max_index);
230232
Value *codegen_shuffle_indices(int bits, const std::vector<int> &indices);
231233
Value *codegen_whilelt(int total_lanes, int start, int end);
232234
void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
@@ -1223,6 +1225,22 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f,
12231225
// and a - (b << c) into umlsl/smlsl.
12241226
func.body = distribute_shifts(func.body, /* multiply_adds */ true);
12251227

1228+
if (target_vscale() > 0) {
1229+
debug(1) << "ARM: Optimizing shuffles...\n";
1230+
const int lut_alignment = 16;
1231+
1232+
auto max_span_query = [&](const Type &lut_type) -> std::vector<int> {
1233+
int vl = natural_vector_size(lut_type);
1234+
// SVE2 has TBL and TBL2 (TBL with two src vectors) LLVM intrinsic.
1235+
// We prioritize TBL with single src vector in favor of performance.
1236+
return {vl, vl * 2};
1237+
};
1238+
1239+
func.body = optimize_shuffles(func.body, lut_alignment, native_vector_bits(), max_span_query, true);
1240+
debug(2) << "ARM: Lowering after optimizing shuffles:\n"
1241+
<< func.body << "\n\n";
1242+
}
1243+
12261244
CodeGen_Posix::compile_func(func, simple_name, extern_name);
12271245
}
12281246

@@ -2250,7 +2268,7 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
22502268
}
22512269

22522270
// Perform vector shuffle by decomposing the operation to multiple native shuffle steps
2253-
// which calls shuffle_scalable_vectors_general() which emits TBL/TBL2 instruction
2271+
// which calls shuffle_scalable_vectors_general() which emits TBL/TBL2 LLVM intrinsic.
22542272
DecomposeVectorShuffle shuffler(*this, a, b, get_vector_num_elements(a->getType()), natural_lanes);
22552273
return shuffler.run(indices);
22562274
}
@@ -2259,41 +2277,50 @@ Value *CodeGen_ARM::shuffle_scalable_vectors_general(Value *a, Value *b, const s
22592277
internal_assert(a) << "Must provide a valid vector operand";
22602278
internal_assert(!indices.empty()) << "Cannot shuffle with empty indices";
22612279

2280+
llvm::Type *elt = get_vector_element_type(a->getType());
2281+
Value *val_indices = codegen_shuffle_indices(elt->getScalarSizeInBits(), indices);
2282+
auto [min_itr, max_itr] = std::minmax_element(indices.begin(), indices.end());
2283+
int highest_lane = *max_itr;
2284+
internal_assert(highest_lane >= 0)
2285+
<< "highest_lane was "
2286+
<< (highest_lane == SliceIndexNone ? "SliceIndexNone" :
2287+
highest_lane == SliceIndexCarryPrevResult ? "SliceIndexCarryPrevResult" :
2288+
"")
2289+
<< " (" << highest_lane << ")";
2290+
2291+
return shuffle_scalable_vectors_general_llvm(a, b, val_indices, *min_itr, *max_itr);
2292+
}
2293+
2294+
Value *CodeGen_ARM::shuffle_scalable_vectors_general_llvm(Value *a, Value *b, Value *indices, int min_index, int max_index) {
2295+
internal_assert(a) << "Must provide a valid vector operand";
2296+
internal_assert(indices) << "Must provide a valid indices";
2297+
22622298
llvm::Type *elt = get_vector_element_type(a->getType());
22632299
const int bits = elt->getScalarSizeInBits();
22642300
const int natural_lanes = natural_vector_size(Int(bits));
22652301
const int src_lanes = get_vector_num_elements(a->getType());
2266-
const int dst_lanes = indices.size();
2302+
const int dst_lanes = get_vector_num_elements(indices->getType());
22672303
llvm::Type *dst_type = get_vector_type(elt, dst_lanes);
22682304

22692305
internal_assert(target_vscale() > 0 && is_scalable_vector(a)) << "Only deal with scalable vectors\n";
22702306
internal_assert(src_lanes == natural_lanes && dst_lanes == natural_lanes)
22712307
<< "Only deal with vector with natural_lanes\n";
22722308

22732309
// We select TBL or TBL2 intrinsic depending on indices range
2274-
int highest_lane = *std::max_element(indices.begin(), indices.end());
2275-
internal_assert(highest_lane >= 0)
2276-
<< "highest_lane was "
2277-
<< (highest_lane == SliceIndexNone ? "SliceIndexNone" :
2278-
highest_lane == SliceIndexCarryPrevResult ? "SliceIndexCarryPrevResult" :
2279-
"")
2280-
<< " (" << highest_lane << ")";
2281-
2282-
bool use_tbl = highest_lane < src_lanes;
2310+
const bool use_tbl = max_index < src_lanes;
22832311
internal_assert(use_tbl || b) << "'b' must be valid in case of tbl2\n";
22842312

22852313
auto instr = concat_strings("llvm.aarch64.sve.", use_tbl ? "tbl" : "tbl2", mangle_llvm_type(dst_type));
22862314

2287-
Value *val_indices = codegen_shuffle_indices(bits, indices);
22882315
llvm::Type *vt_natural = get_vector_type(elt, natural_lanes);
22892316
std::vector<llvm::Type *> llvm_arg_types;
22902317
std::vector<llvm::Value *> llvm_arg_vals;
22912318
if (use_tbl) {
2292-
llvm_arg_types = {vt_natural, val_indices->getType()};
2293-
llvm_arg_vals = {a, val_indices};
2319+
llvm_arg_types = {vt_natural, indices->getType()};
2320+
llvm_arg_vals = {a, indices};
22942321
} else {
2295-
llvm_arg_types = {vt_natural, vt_natural, val_indices->getType()};
2296-
llvm_arg_vals = {a, b, val_indices};
2322+
llvm_arg_types = {vt_natural, vt_natural, indices->getType()};
2323+
llvm_arg_vals = {a, b, indices};
22972324
}
22982325
llvm::FunctionType *fn_type = FunctionType::get(vt_natural, llvm_arg_types, false);
22992326
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
@@ -2383,6 +2410,41 @@ void CodeGen_ARM::visit(const Call *op) {
23832410
value = codegen(lower_round_to_nearest_ties_to_even(op->args[0]));
23842411
return;
23852412
}
2413+
} else if (op->is_intrinsic(Call::dynamic_shuffle)) {
2414+
internal_assert(target_vscale() > 0);
2415+
internal_assert(op->args.size() == 4);
2416+
const auto min_index = as_const_int(op->args[2]);
2417+
const auto max_index = as_const_int(op->args[3]);
2418+
internal_assert(min_index.has_value() && max_index.has_value());
2419+
2420+
Type lut_type = op->args[0].type();
2421+
const int src_lanes = lut_type.lanes();
2422+
const int dst_lanes = op->args[1].type().lanes();
2423+
const int natural_lanes = natural_vector_size(lut_type);
2424+
2425+
debug(3) << "dynamic_shuffle: [" << *min_index << ", " << *max_index << "]"
2426+
<< ", natural_lanes:" << natural_lanes << ", src_lanes:" << src_lanes << "\n";
2427+
2428+
Value *src = codegen(op->args[0]);
2429+
internal_assert(src_lanes <= natural_lanes * 2) << "src is too long to dynamic_shuffle\n";
2430+
Value *src_a = slice_vector(src, 0, natural_lanes);
2431+
Value *src_b = (src_lanes > natural_lanes) ? slice_vector(src, natural_lanes, natural_lanes) : nullptr;
2432+
2433+
// Cast index to integer with the same bits as LUT data
2434+
Type index_type = UInt(lut_type.bits()).with_lanes(dst_lanes);
2435+
Expr indices = cast(index_type, op->args[1]);
2436+
Value *val_indices = codegen(indices);
2437+
2438+
std::vector<Value *> slices;
2439+
const int num_slices = align_up(dst_lanes, natural_lanes) / natural_lanes;
2440+
slices.reserve(num_slices);
2441+
for (int i = 0; i < num_slices; i++) {
2442+
Value *indices_slice = slice_vector(val_indices, i * natural_lanes, natural_lanes);
2443+
Value *dst_slice = shuffle_scalable_vectors_general_llvm(src_a, src_b, indices_slice, *min_index, *max_index);
2444+
slices.push_back(dst_slice);
2445+
}
2446+
value = slice_vector(concat_vectors(slices), 0, dst_lanes);
2447+
return;
23862448
}
23872449

23882450
if (op->type.is_vector()) {

src/CodeGen_Hexagon.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,9 @@ void CodeGen_Hexagon::visit(const Call *op) {
19351935
auto max_index = as_const_int(op->args[3]);
19361936
internal_assert(min_index && max_index);
19371937
Value *lut = codegen(op->args[0]);
1938-
Value *idx = codegen(op->args[1]);
1938+
// Cast the index to 8 bit
1939+
Expr index = cast(UInt(8).with_lanes(op->type.lanes()), op->args[1]);
1940+
Value *idx = codegen(index);
19391941
value = vlut(lut, idx, *min_index, *max_index);
19401942
return;
19411943
} else if (op->is_intrinsic(Call::abs)) {

src/HexagonOptimize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,8 @@ class SyncronizationBarriers : public IRMutator {
22852285
Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment) {
22862286
// Replace indirect and other complicated loads with
22872287
// dynamic_shuffle (vlut) calls.
2288-
return optimize_shuffles(s, lut_alignment);
2288+
auto max_span_query = [](const Type &t) -> std::vector<int> { return {256}; };
2289+
return optimize_shuffles(s, lut_alignment, 1024, max_span_query, false);
22892290
}
22902291

22912292
Stmt scatter_gather_generator(Stmt s) {

src/OptimizeShuffles.cpp

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ namespace Internal {
2121

2222
namespace {
2323

24+
using SpanQueryType = std::function<std::vector<int>(const Type &)>;
25+
2426
class OptimizeShuffles : public IRMutator {
2527
int lut_alignment;
28+
int native_vector_bits;
29+
SpanQueryType get_max_span_sizes;
30+
bool align_loads_with_native_vector;
2631
Scope<Interval> bounds;
2732
std::vector<std::pair<std::string, Expr>> lets;
2833

@@ -67,7 +72,7 @@ class OptimizeShuffles : public IRMutator {
6772
if (allocations_to_pad.count(op->name)) {
6873
op = s.as<Allocate>();
6974
internal_assert(op);
70-
int padding = 128 / op->type.bytes(); // One native vector
75+
int padding = native_vector_bits / op->type.bits(); // One native vector
7176
return Allocate::make(op->name, op->type, op->memory_type,
7277
op->extents, op->condition,
7378
op->body, op->new_expr, op->free_function,
@@ -99,34 +104,40 @@ class OptimizeShuffles : public IRMutator {
99104
((unaligned_index_bounds.max + align) / align) * align - 1};
100105
ModulusRemainder alignment(align, 0);
101106

102-
for (const Interval &index_bounds : {aligned_index_bounds, unaligned_index_bounds}) {
103-
Expr index_span = span_of_bounds(index_bounds);
104-
index_span = common_subexpression_elimination(index_span);
105-
index_span = simplify(index_span);
106-
107-
if (can_prove(index_span < 256)) {
108-
// This is a lookup within an up to 256 element array. We
109-
// can use dynamic_shuffle for this.
110-
int const_extent = as_const_int(index_span) ? *as_const_int(index_span) + 1 : 256;
111-
Expr base = simplify(index_bounds.min);
112-
113-
// Load all of the possible indices loaded from the
114-
// LUT. Note that for clamped ramps, this loads up to 1
115-
// vector past the max, so we will add padding to the
116-
// allocation accordingly (if we're the one that made it).
117-
allocations_to_pad.insert(op->name);
118-
Expr lut = Load::make(op->type.with_lanes(const_extent), op->name,
119-
Ramp::make(base, 1, const_extent),
120-
op->image, op->param, const_true(const_extent), alignment);
121-
122-
// We know the size of the LUT is not more than 256, so we
123-
// can safely cast the index to 8 bit, which
124-
// dynamic_shuffle requires.
125-
index = simplify(cast(UInt(8).with_lanes(op->type.lanes()), index - base));
126-
return Call::make(op->type, "dynamic_shuffle", {lut, index, 0, const_extent - 1}, Call::PureIntrinsic);
107+
const int native_vector_size = native_vector_bits / op->type.bits();
108+
109+
for (const auto &max_span_size : get_max_span_sizes(op->type)) {
110+
111+
for (const Interval &index_bounds : {aligned_index_bounds, unaligned_index_bounds}) {
112+
Expr index_span = span_of_bounds(index_bounds);
113+
index_span = common_subexpression_elimination(index_span);
114+
index_span = simplify(index_span);
115+
116+
if (can_prove(index_span < max_span_size)) {
117+
// This is a lookup within an up to max_span_size element array. We
118+
// can use dynamic_shuffle for this.
119+
int const_extent = as_const_int(index_span) ? *as_const_int(index_span) + 1 : max_span_size;
120+
if (align_loads_with_native_vector) {
121+
const_extent = align_up(const_extent, native_vector_size);
122+
}
123+
Expr base = simplify(index_bounds.min);
124+
125+
// Load all of the possible indices loaded from the
126+
// LUT. Note that for clamped ramps, this loads up to 1
127+
// vector past the max, so we will add padding to the
128+
// allocation accordingly (if we're the one that made it).
129+
allocations_to_pad.insert(op->name);
130+
Expr lut = Load::make(op->type.with_lanes(const_extent), op->name,
131+
Ramp::make(base, 1, const_extent),
132+
op->image, op->param, const_true(const_extent), alignment);
133+
134+
// Target dependent codegen needs to cast the type of index to what it accepts
135+
index = simplify(index - base);
136+
return Call::make(op->type, "dynamic_shuffle", {lut, index, 0, const_extent - 1}, Call::PureIntrinsic);
137+
}
138+
// Only the first iteration of this loop is aligned.
139+
alignment = ModulusRemainder();
127140
}
128-
// Only the first iteration of this loop is aligned.
129-
alignment = ModulusRemainder();
130141
}
131142
}
132143
if (!index.same_as(op->index)) {
@@ -137,14 +148,17 @@ class OptimizeShuffles : public IRMutator {
137148
}
138149

139150
public:
140-
OptimizeShuffles(int lut_alignment)
141-
: lut_alignment(lut_alignment) {
151+
OptimizeShuffles(int lut_alignment, int native_vector_bits, SpanQueryType get_max_span_sizes, bool align_loads_with_native_vector)
152+
: lut_alignment(lut_alignment),
153+
native_vector_bits(native_vector_bits),
154+
get_max_span_sizes(std::move(get_max_span_sizes)),
155+
align_loads_with_native_vector(align_loads_with_native_vector) {
142156
}
143157
};
144158
} // namespace
145159

146-
Stmt optimize_shuffles(Stmt s, int lut_alignment) {
147-
s = OptimizeShuffles(lut_alignment)(s);
160+
Stmt optimize_shuffles(Stmt s, int lut_alignment, int native_vector_bits, SpanQueryType get_max_span_sizes, bool align_loads_with_native_vector) {
161+
s = OptimizeShuffles(lut_alignment, native_vector_bits, std::move(get_max_span_sizes), align_loads_with_native_vector)(s);
148162
return s;
149163
}
150164

src/OptimizeShuffles.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@
77
*/
88

99
#include "Expr.h"
10+
#include <functional>
11+
#include <vector>
1012

1113
namespace Halide {
1214
namespace Internal {
1315

1416
/* Replace indirect loads with dynamic_shuffle intrinsics where
1517
possible. */
16-
Stmt optimize_shuffles(Stmt s, int lut_alignment);
18+
Stmt optimize_shuffles(Stmt s,
19+
int lut_alignment,
20+
int native_vector_bits,
21+
std::function<std::vector<int>(const Type &)> get_max_span_sizes,
22+
bool align_loads_with_native_vector);
1723

1824
} // namespace Internal
1925
} // namespace Halide

test/correctness/simd_op_check_sve2.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
855855
if (instr_lanes < 2 || (total_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
856856

857857
AddTestFunctor add(*this, bits, total_lanes);
858-
Expr index = clamp(cast<int>(in_im(x)), 0, W - 1);
858+
Expr index = clamp(in_i32(x), 0, W - 1);
859859
Func tmp;
860860
tmp(x, y) = cast(elt, y);
861861
tmp(x, index) = cast(elt, 1);
@@ -876,6 +876,38 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
876876
}
877877
}
878878
}
879+
880+
// Gather load where index range is bounded within certain value. e.g. LUT
881+
// In this case, Halide tries to transform it into contiguous load + Call::dynamic_shuffle
882+
// which is lowered to TBL instruction. (see OptimizeShuffles.cpp)
883+
if (has_sve()) {
884+
const int width = base_vec_bits;
885+
const int total_lanes = width / bits;
886+
const int instr_lanes = Instruction::get_instr_lanes(bits, total_lanes, target);
887+
if (instr_lanes < 2 || (total_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>
888+
889+
AddTestFunctor add(*this, bits, total_lanes);
890+
const std::vector<std::pair<int, int>> index_min_max{
891+
{0, total_lanes - 1},
892+
{1, total_lanes},
893+
{0, total_lanes * 2 - 1},
894+
};
895+
for (auto &[index_min, index_max] : index_min_max) {
896+
Expr index = cast(Int(32), in_im(x));
897+
index = clamp(index, index_min, index_max);
898+
Expr look_up = in_im(index);
899+
900+
add("tbl", look_up);
901+
}
902+
903+
// Without clamped but bounded by the range of the data type of the input image (8bit)
904+
Expr index = cast(Int(32), in_u8(x)); // 8 bit fixed
905+
int factor = (1 << 8) / (total_lanes * 2);
906+
index = index / factor; // index should be within native_vector*2 range
907+
Expr look_up = in_im(index);
908+
909+
add("tbl", look_up);
910+
}
879911
}
880912
}
881913

0 commit comments

Comments
 (0)