1414#include " IROperator.h"
1515#include " IRPrinter.h"
1616#include " LLVM_Headers.h"
17+ #include " OptimizeShuffles.h"
1718#include " Simplify.h"
1819#include " Substitute.h"
1920#include " Util.h"
@@ -227,6 +228,7 @@ class CodeGen_ARM : public CodeGen_Posix {
227228 Value *interleave_vectors (const std::vector<Value *> &) override ;
228229 Value *shuffle_vectors (Value *a, Value *b, const std::vector<int > &indices) override ;
229230 Value *shuffle_scalable_vectors_general (Value *a, Value *b, const std::vector<int > &indices);
231+ Value *shuffle_scalable_vectors_general_llvm (Value *a, Value *b, Value *indices, int min_index, int max_index);
230232 Value *codegen_shuffle_indices (int bits, const std::vector<int > &indices);
231233 Value *codegen_whilelt (int total_lanes, int start, int end);
232234 void codegen_vector_reduce (const VectorReduce *, const Expr &) override ;
@@ -1223,6 +1225,22 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f,
12231225 // and a - (b << c) into umlsl/smlsl.
12241226 func.body = distribute_shifts (func.body , /* multiply_adds */ true );
12251227
1228+ if (target_vscale () > 0 ) {
1229+ debug (1 ) << " ARM: Optimizing shuffles...\n " ;
1230+ const int lut_alignment = 16 ;
1231+
1232+ auto max_span_query = [&](const Type &lut_type) -> std::vector<int > {
1233+ int vl = natural_vector_size (lut_type);
1234+ // SVE2 provides TBL and TBL2 (TBL with two src vectors) LLVM intrinsics.
1235+ // We prefer TBL with a single src vector for better performance.
1236+ return {vl, vl * 2 };
1237+ };
1238+
1239+ func.body = optimize_shuffles (func.body , lut_alignment, native_vector_bits (), max_span_query, true );
1240+ debug (2 ) << " ARM: Lowering after optimizing shuffles:\n "
1241+ << func.body << " \n\n " ;
1242+ }
1243+
12261244 CodeGen_Posix::compile_func (func, simple_name, extern_name);
12271245}
12281246
@@ -2250,7 +2268,7 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
22502268 }
22512269
22522270 // Perform vector shuffle by decomposing the operation to multiple native shuffle steps
2253- // which calls shuffle_scalable_vectors_general() which emits TBL/TBL2 instruction
2271+ // which calls shuffle_scalable_vectors_general() which emits TBL/TBL2 LLVM intrinsic.
22542272 DecomposeVectorShuffle shuffler (*this , a, b, get_vector_num_elements (a->getType ()), natural_lanes);
22552273 return shuffler.run (indices);
22562274}
@@ -2259,41 +2277,50 @@ Value *CodeGen_ARM::shuffle_scalable_vectors_general(Value *a, Value *b, const s
22592277 internal_assert (a) << " Must provide a valid vector operand" ;
22602278 internal_assert (!indices.empty ()) << " Cannot shuffle with empty indices" ;
22612279
2280+ llvm::Type *elt = get_vector_element_type (a->getType ());
2281+ Value *val_indices = codegen_shuffle_indices (elt->getScalarSizeInBits (), indices);
2282+ auto [min_itr, max_itr] = std::minmax_element (indices.begin (), indices.end ());
2283+ int highest_lane = *max_itr;
2284+ internal_assert (highest_lane >= 0 )
2285+ << " highest_lane was "
2286+ << (highest_lane == SliceIndexNone ? " SliceIndexNone" :
2287+ highest_lane == SliceIndexCarryPrevResult ? " SliceIndexCarryPrevResult" :
2288+ " " )
2289+ << " (" << highest_lane << " )" ;
2290+
2291+ return shuffle_scalable_vectors_general_llvm (a, b, val_indices, *min_itr, *max_itr);
2292+ }
2293+
2294+ Value *CodeGen_ARM::shuffle_scalable_vectors_general_llvm (Value *a, Value *b, Value *indices, int min_index, int max_index) {
2295+ internal_assert (a) << " Must provide a valid vector operand" ;
2296+ internal_assert (indices) << " Must provide a valid indices" ;
2297+
22622298 llvm::Type *elt = get_vector_element_type (a->getType ());
22632299 const int bits = elt->getScalarSizeInBits ();
22642300 const int natural_lanes = natural_vector_size (Int (bits));
22652301 const int src_lanes = get_vector_num_elements (a->getType ());
2266- const int dst_lanes = indices. size ( );
2302+ const int dst_lanes = get_vector_num_elements ( indices-> getType () );
22672303 llvm::Type *dst_type = get_vector_type (elt, dst_lanes);
22682304
22692305 internal_assert (target_vscale () > 0 && is_scalable_vector (a)) << " Only deal with scalable vectors\n " ;
22702306 internal_assert (src_lanes == natural_lanes && dst_lanes == natural_lanes)
22712307 << " Only deal with vector with natural_lanes\n " ;
22722308
22732309 // We select TBL or TBL2 intrinsic depending on indices range
2274- int highest_lane = *std::max_element (indices.begin (), indices.end ());
2275- internal_assert (highest_lane >= 0 )
2276- << " highest_lane was "
2277- << (highest_lane == SliceIndexNone ? " SliceIndexNone" :
2278- highest_lane == SliceIndexCarryPrevResult ? " SliceIndexCarryPrevResult" :
2279- " " )
2280- << " (" << highest_lane << " )" ;
2281-
2282- bool use_tbl = highest_lane < src_lanes;
2310+ const bool use_tbl = max_index < src_lanes;
22832311 internal_assert (use_tbl || b) << " 'b' must be valid in case of tbl2\n " ;
22842312
22852313 auto instr = concat_strings (" llvm.aarch64.sve." , use_tbl ? " tbl" : " tbl2" , mangle_llvm_type (dst_type));
22862314
2287- Value *val_indices = codegen_shuffle_indices (bits, indices);
22882315 llvm::Type *vt_natural = get_vector_type (elt, natural_lanes);
22892316 std::vector<llvm::Type *> llvm_arg_types;
22902317 std::vector<llvm::Value *> llvm_arg_vals;
22912318 if (use_tbl) {
2292- llvm_arg_types = {vt_natural, val_indices ->getType ()};
2293- llvm_arg_vals = {a, val_indices };
2319+ llvm_arg_types = {vt_natural, indices ->getType ()};
2320+ llvm_arg_vals = {a, indices };
22942321 } else {
2295- llvm_arg_types = {vt_natural, vt_natural, val_indices ->getType ()};
2296- llvm_arg_vals = {a, b, val_indices };
2322+ llvm_arg_types = {vt_natural, vt_natural, indices ->getType ()};
2323+ llvm_arg_vals = {a, b, indices };
22972324 }
22982325 llvm::FunctionType *fn_type = FunctionType::get (vt_natural, llvm_arg_types, false );
22992326 FunctionCallee fn = module ->getOrInsertFunction (instr, fn_type);
@@ -2383,6 +2410,41 @@ void CodeGen_ARM::visit(const Call *op) {
23832410 value = codegen (lower_round_to_nearest_ties_to_even (op->args [0 ]));
23842411 return ;
23852412 }
2413+ } else if (op->is_intrinsic (Call::dynamic_shuffle)) {
2414+ internal_assert (target_vscale () > 0 );
2415+ internal_assert (op->args .size () == 4 );
2416+ const auto min_index = as_const_int (op->args [2 ]);
2417+ const auto max_index = as_const_int (op->args [3 ]);
2418+ internal_assert (min_index.has_value () && max_index.has_value ());
2419+
2420+ Type lut_type = op->args [0 ].type ();
2421+ const int src_lanes = lut_type.lanes ();
2422+ const int dst_lanes = op->args [1 ].type ().lanes ();
2423+ const int natural_lanes = natural_vector_size (lut_type);
2424+
2425+ debug (3 ) << " dynamic_shuffle: [" << *min_index << " , " << *max_index << " ]"
2426+ << " , natural_lanes:" << natural_lanes << " , src_lanes:" << src_lanes << " \n " ;
2427+
2428+ Value *src = codegen (op->args [0 ]);
2429+ internal_assert (src_lanes <= natural_lanes * 2 ) << " src is too long to dynamic_shuffle\n " ;
2430+ Value *src_a = slice_vector (src, 0 , natural_lanes);
2431+ Value *src_b = (src_lanes > natural_lanes) ? slice_vector (src, natural_lanes, natural_lanes) : nullptr ;
2432+
2433+ // Cast index to an integer with the same bit width as the LUT data
2434+ Type index_type = UInt (lut_type.bits ()).with_lanes (dst_lanes);
2435+ Expr indices = cast (index_type, op->args [1 ]);
2436+ Value *val_indices = codegen (indices);
2437+
2438+ std::vector<Value *> slices;
2439+ const int num_slices = align_up (dst_lanes, natural_lanes) / natural_lanes;
2440+ slices.reserve (num_slices);
2441+ for (int i = 0 ; i < num_slices; i++) {
2442+ Value *indices_slice = slice_vector (val_indices, i * natural_lanes, natural_lanes);
2443+ Value *dst_slice = shuffle_scalable_vectors_general_llvm (src_a, src_b, indices_slice, *min_index, *max_index);
2444+ slices.push_back (dst_slice);
2445+ }
2446+ value = slice_vector (concat_vectors (slices), 0 , dst_lanes);
2447+ return ;
23862448 }
23872449
23882450 if (op->type .is_vector ()) {
0 commit comments