@@ -5013,11 +5013,315 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
50135013 ctx.target_shape ));
50145014 ArrayRef<int64_t > permutation = transpose_op.getPermutation ();
50155015 const auto tile_perm = permutation.take_back (2 );
5016+
5017+ // Major minor pemute
50165018 if (tile_perm != ArrayRef<int64_t >{rank - 2 , rank - 1 } &&
50175019 tile_perm != ArrayRef<int64_t >{rank - 1 , rank - 2 }) {
5018- return transpose_op->emitOpError (
5019- " Not implemented: Unsupported permutation" );
5020+ // This is a 3 stage algorithm that uses combinations and shuffles
5021+ // to do a transposition of an 8x8 block of sublanes.
5022+ // In the following algorithm description, A, B, ..., H represent 8
5023+ // distinct input vregs that form an 8x8 block of data
5024+ // to be transposed. In our notation, B2 identifies the third
5025+ // sublane (2) of the second vreg (B)".
5026+ //
5027+ //
5028+ // If we think of each starting input vreg as a row in an 8x8 block of
5029+ // elements:
5030+ // A: A0 A1 A2 A3 A4 A5 A6 A7
5031+ // B: B0 B1 B2 B3 B4 B5 B6 B7
5032+ // ...
5033+ // H: H0 H1 H2 H3 H4 H5 H6 H7
5034+ //
5035+ // The goal is to transpose this block, so the output vregs are:
5036+ // out0: A0 B0 C0 D0 E0 F0 G0 H0
5037+ // out1: A1 B1 C1 D1 E1 F1 G1 H1
5038+ // ...
5039+ // out7: A7 B7 C7 D7 E7 F7 G7 H7
5040+ //
5041+ // Stage 1: Operates on pairs of input vregs (e.g., A and B).
5042+ //
5043+ // Input to Stage 1 (example pair A, B):
5044+ // A: A0 A1 A2 A3 A4 A5 A6 A7
5045+ // B: B0 B1 B2 B3 B4 B5 B6 B7
5046+ //
5047+ // Step 1.1: Combine low/high halves.
5048+ // combine_low(A, B) -> CL_AB: [A0 A1 A2 A3 | B0 B1 B2 B3] (8 elements)
5049+ // combine_high(A, B) -> CH_AB: [A4 A5 A6 A7 | B4 B5 B6 B7] (8 elements)
5050+ // (Notation: '|' separates the 4 elements from A and 4 from B)
5051+ //
5052+ // Step 1.2: Shuffle.
5053+ // The shuffle pattern for the low part (applied to CL_AB using
5054+ // `shuffle(CL_AB, CH_AB, pattern)`) is {0, 4, 1, 5, 2, 6, 3, 7}.
5055+ // The shuffle pattern for the high part (applied to CH_AB using
5056+ // `shuffle(CL_AB, CH_AB, pattern)`) is {8, 12, 9, 13, 10, 14, 11, 15}.
5057+ // (Indices 0-7 in shuffle refer to CL_AB, 8-15 to CH_AB).
5058+ // This results in:
5059+ // s1_AB_0: A0 B0 A1 B1 A2 B2 A3 B3 (from shuffling CL_AB elements)
5060+ // s1_AB_1: A4 B4 A5 B5 A6 B6 A7 B7 (from shuffling CH_AB elements)
5061+ //
5062+ // Output of Stage 1 / Input to Stage 2 (example for A,B,C,D processing):
5063+ // s1_vregs[0] (from A,B): A0 B0 A1 B1 A2 B2 A3 B3
5064+ // s1_vregs[1] (from A,B): A4 B4 A5 B5 A6 B6 A7 B7
5065+ // s1_vregs[2] (from C,D): C0 D0 C1 D1 C2 D2 C3 D3
5066+ // s1_vregs[3] (from C,D): C4 D4 C5 D5 C6 D6 C7 D7
5067+ // ... (and so on for E,F,G,H into s1_vregs[4-7])
5068+
5069+ // Stage 2: Operates on groups of 4 vregs from Stage 1 output.
5070+ // (e.g., s1_vregs[0], s1_vregs[1], s1_vregs[2], s1_vregs[3])
5071+ //
5072+ // Input to Stage 2 (example processing s1_vregs[0] and s1_vregs[2]):
5073+ // X = s1_vregs[0] = [A0 B0 A1 B1 | A2 B2 A3 B3]
5074+ // Y = s1_vregs[2] = [C0 D0 C1 D1 | C2 D2 C3 D3]
5075+ //
5076+ // Step 2.1: Combine low/high halves.
5077+ // combine_low(X, Y) -> CL_XY: [A0 B0 A1 B1 | C0 D0 C1 D1]
5078+ // combine_high(X, Y) -> CH_XY: [A2 B2 A3 B3 | C2 D2 C3 D3]
5079+ //
5080+ // (Similarly for s1_vregs[1] and s1_vregs[3], let them be X' and Y')
5081+ // combine_low(X', Y') -> CL_X'Y': [A4 B4 A5 B5 | C4 D4 C5 D5]
5082+ // combine_high(X', Y') -> CH_X'Y': [A6 B6 A7 B7 | C6 D6 C7 D7]
5083+ //
5084+ // Step 2.2: Shuffle.
5085+ // The shuffle pattern for the low part (e.g., applied to CL_XY) is {0, 1,
5086+ // 4, 5, 2, 3, 6, 7}. The shuffle pattern for the high part (e.g., applied
5087+ // to CH_XY, effectively) is {8, 9, 12, 13, 10, 11, 14, 15}.
5088+ //
5089+ // This results in (for the first group of 4 input vregs A,B,C,D):
5090+ // s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1 (from shuffling CL_XY elements)
5091+ // s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3 (from shuffling CH_XY elements)
5092+ // s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5 (from shuffling CL_X'Y' elements)
5093+ // s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7 (from shuffling CH_X'Y' elements)
5094+ //
5095+ // Output of Stage 2 / Input to Stage 3:
5096+ // s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1
5097+ // s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3
5098+ // s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5
5099+ // s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7
5100+ // s2_vregs[4]: E0 F0 G0 H0 E1 F1 G1 H1 (from E,F,G,H processing)
5101+ // s2_vregs[5]: E2 F2 G2 H2 E3 F3 G3 H3
5102+ // s2_vregs[6]: E4 F4 G4 H4 E5 F5 G5 H5
5103+ // s2_vregs[7]: E6 F6 G6 H6 E7 F7 G7 H7
5104+
5105+ // Stage 3: Combine results from Stage 2. No shuffle needed after combine.
5106+ // Input to Stage 3 (example for the first two rows of the final transpose):
5107+ // L = s2_vregs[0] = [A0 B0 C0 D0 | A1 B1 C1 D1]
5108+ // R = s2_vregs[4] = [E0 F0 G0 H0 | E1 F1 G1 H1]
5109+ //
5110+ // Step 3.1: Combine low/high halves.
5111+ // combine_low(L, R) -> [A0 B0 C0 D0 | E0 F0 G0 H0] ->
5112+ // Final out0: A0 B0 C0 D0 E0 F0 G0 H0
5113+ // combine_high(L, R) -> [A1 B1 C1 D1 | E1 F1 G1 H1] ->
5114+ // Final out1: A1 B1 C1 D1 E1 F1 G1 H1
5115+ // ... and so on for other pairs from Stage 2 output
5116+ // (e.g. L=s2_vregs[1], R=s2_vregs[5]).
5117+ //
5118+ // This results in the correctly transposed 8x8 block.
5119+
5120+ constexpr int64_t kMajorDimOriginalIdx = 0 ;
5121+ constexpr int64_t kSecondMinorDimOriginalIdx = 1 ;
5122+ constexpr int64_t kMinorMostDimOriginalIdx = 2 ;
5123+
5124+ auto vec_shape = src_ty.getShape ();
5125+ auto major_dim_size = vec_shape[kMajorDimOriginalIdx ];
5126+ auto second_minor_dim_size = vec_shape[kSecondMinorDimOriginalIdx ];
5127+
5128+ if (layout_in.offsets () != LayoutOffsets{0 , 0 }) {
5129+ return transpose_op.emitOpError (" Not implemented: Layout with offset." );
5130+ }
5131+ if (layout_in.implicit_dim () != VectorLayout::ImplicitDim::kNone ) {
5132+ return transpose_op.emitOpError (
5133+ " Not implemented: Layout with implicit dimension." );
5134+ }
5135+
5136+ auto sublane_count = ctx.target_shape [0 ];
5137+ if (second_minor_dim_size % sublane_count != 0 ||
5138+ major_dim_size % sublane_count != 0 ) {
5139+ return transpose_op.emitOpError (
5140+ " Not implemented: Swapping major and second minor dimensions must "
5141+ " result in dimension sizes that are multiples of sublane_count." );
5142+ }
5143+
5144+ if (!layout_in.hasNativeTiling (ctx.target_shape )) {
5145+ return transpose_op.emitOpError (
5146+ " Not implemented: Expected native input tiling." );
5147+ }
5148+ if (layout_in != layout_out) {
5149+ return transpose_op.emitOpError (
5150+ " Not implemented: Expected same input and output layouts." );
5151+ }
5152+ xla::Array<Value> dst_vregs (
5153+ layout_out.tileArrayShape (dst_ty.getShape (), ctx.target_shape ));
5154+
5155+ if (layout_in.bitwidth () != 32 ) {
5156+ return transpose_op.emitOpError (
5157+ " Not implemented: Major-second-minor transpose only supported for "
5158+ " 32-bit vectors. Also, input must be a vector type." );
5159+ }
5160+ if (ctx.target_shape [0 ] != 8 ) {
5161+ return transpose_op.emitOpError (
5162+ " Not implemented: Major-second-minor transpose expects 8 sublanes." );
5163+ }
5164+
5165+ auto vreg_dimensions = src_vregs.dimensions ();
5166+ // Note(mvoz): Slice is a weird word here, This is used for constructing
5167+ // the output vregs - the reason we divide here is because we multiply it
5168+ // back later on to get the correct index into src_vregs, but the reason
5169+ // we cannot just resolve that in our outer loop is because of the nature
5170+ // of a transpose - this dim value goes unmultiplied into the output vregs.
5171+ // effectively, our indexing:
5172+ // {major_dim_slice_idx * sublane_count, second_minor_dim_slice_idx,
5173+ // minor_most_dim_slice_idx} becomes {second_minor_dim_slice_idx *
5174+ // sublane_count, major_dim_slice_idx, minor_most_dim_slice_idx}
5175+ auto num_slices_in_major_dim =
5176+ vreg_dimensions[kMajorDimOriginalIdx ] / sublane_count;
5177+ auto num_slices_in_second_minor_dim =
5178+ vreg_dimensions[kSecondMinorDimOriginalIdx ];
5179+ auto num_slices_in_minor_most_dim =
5180+ vreg_dimensions[kMinorMostDimOriginalIdx ];
5181+
5182+ auto shuffle = [&](Value lhs_vreg, Value rhs_vreg, ArrayRef<int > pattern) {
5183+ auto lhs_vreg_type = lhs_vreg.getType ();
5184+ auto pattern_attr = builder.getDenseI32ArrayAttr (pattern);
5185+ return builder
5186+ .create <tpu::SublaneShuffleOp>(transpose_op.getLoc (), lhs_vreg_type,
5187+ lhs_vreg, rhs_vreg, pattern_attr)
5188+ .getResult ();
5189+ };
5190+
5191+ static constexpr std::array<int , 8 > combine_low_pattern = {0 , 1 , 2 , 3 ,
5192+ 8 , 9 , 10 , 11 };
5193+ static constexpr std::array<int , 8 > combine_high_pattern = {4 , 5 , 6 , 7 ,
5194+ 12 , 13 , 14 , 15 };
5195+
5196+ auto combine_low = [&](Value lhs_vreg, Value rhs_vreg) {
5197+ return shuffle (lhs_vreg, rhs_vreg, combine_low_pattern);
5198+ };
5199+ auto combine_high = [&](Value lhs_vreg, Value rhs_vreg) {
5200+ return shuffle (lhs_vreg, rhs_vreg, combine_high_pattern);
5201+ };
5202+
5203+ // Shuffle patterns for Stage 1
5204+ // Input to shuffle: (combine_low_val, combine_high_val)
5205+ // combine_low_val has A0-A3, B0-B3. Indices 0-7 for shuffle.
5206+ // combine_high_val has A4-A7, B4-B7. Indices 8-15 for shuffle.
5207+ static constexpr std::array<int , 8 > permute_pattern_stage1_low_arr = {
5208+ 0 , 4 , 1 , 5 ,
5209+ 2 , 6 , 3 , 7 }; // Selects from combine_low_val to make A0B0A1B1A2B2A3B3
5210+ static constexpr std::array<int , 8 > permute_pattern_stage1_high_arr = {
5211+ 8 , 12 , 9 , 13 , 10 ,
5212+ 14 , 11 , 15 }; // Selects from combine_high_val to make A4B4A5B5A6B6A7B7
5213+
5214+ // Shuffle patterns for Stage 2
5215+ // Input to shuffle: (CL_XY, CH_XY) from Step 2.1 in comments.
5216+ // CL_XY has A0B0A1B1C0D0C1D1. Indices 0-7 for shuffle.
5217+ // CH_XY has A2B2A3B3C2D2C3D3. Indices 8-15 for shuffle.
5218+ static constexpr std::array<int , 8 > permute_pattern_stage2_low_arr = {
5219+ 0 , 1 , 4 , 5 , 2 , 3 , 6 , 7 }; // Selects from CL_XY to make A0B0C0D0A1B1C1D1
5220+ static constexpr std::array<int , 8 > permute_pattern_stage2_high_arr = {
5221+ 8 , 9 , 12 , 13 ,
5222+ 10 , 11 , 14 , 15 }; // Selects from CH_XY to make A2B2C2D2A3B3C3D3
5223+
5224+ for (int major_dim_slice_idx = 0 ;
5225+ major_dim_slice_idx < num_slices_in_major_dim; ++major_dim_slice_idx) {
5226+ for (int second_minor_dim_slice_idx = 0 ;
5227+ second_minor_dim_slice_idx < num_slices_in_second_minor_dim;
5228+ ++second_minor_dim_slice_idx) {
5229+ for (int minor_most_dim_slice_idx = 0 ;
5230+ minor_most_dim_slice_idx < num_slices_in_minor_most_dim;
5231+ ++minor_most_dim_slice_idx) {
5232+ // STAGE 1!
5233+ std::array<Value, 8 >
5234+ stage1_output_vregs; // Stores s1_vregs from comments
5235+ constexpr int num_pairs_stage1 =
5236+ 4 ; // Processes 4 pairs of vregs (A,B), (C,D), (E,F), (G,H)
5237+
5238+ for (int i = 0 ; i < num_pairs_stage1; ++i) {
5239+ Value first_vreg = src_vregs (
5240+ {(2 * i) + (sublane_count * major_dim_slice_idx),
5241+ second_minor_dim_slice_idx, minor_most_dim_slice_idx});
5242+ Value second_vreg = src_vregs (
5243+ {(2 * i) + (sublane_count * major_dim_slice_idx) + 1 ,
5244+ second_minor_dim_slice_idx, minor_most_dim_slice_idx});
5245+
5246+ auto combined_low_val = combine_low (first_vreg, second_vreg);
5247+ auto combined_high_val = combine_high (first_vreg, second_vreg);
5248+
5249+ stage1_output_vregs[2 * i] =
5250+ shuffle (combined_low_val, combined_high_val,
5251+ permute_pattern_stage1_low_arr);
5252+ stage1_output_vregs[2 * i + 1 ] =
5253+ shuffle (combined_low_val, combined_high_val,
5254+ permute_pattern_stage1_high_arr);
5255+ }
5256+
5257+ // STAGE 2!
5258+ std::array<Value, 8 >
5259+ stage2_output_vregs; // Stores s2_vregs from comments
5260+ constexpr int num_pairs_stage2 =
5261+ 4 ; // Processes 4 pairs of vregs from stage1_output_vregs
5262+
5263+ for (int i = 0 ; i < num_pairs_stage2; ++i) {
5264+ // Determine the indices for the input pair from
5265+ // stage1_output_vregs. The 4 pairs processed in this stage are:
5266+ // i=0: (s1_vregs[0], s1_vregs[2])
5267+ // i=1: (s1_vregs[1], s1_vregs[3])
5268+ // i=2: (s1_vregs[4], s1_vregs[6])
5269+ // i=3: (s1_vregs[5], s1_vregs[7])
5270+ int s1_lhs_idx = (i / 2 ) * 4 + (i % 2 );
5271+ int s1_rhs_idx = s1_lhs_idx + 2 ;
5272+
5273+ Value s1_lhs_vreg = stage1_output_vregs[s1_lhs_idx];
5274+ Value s1_rhs_vreg = stage1_output_vregs[s1_rhs_idx];
5275+
5276+ auto combined_low_val = combine_low (s1_lhs_vreg, s1_rhs_vreg);
5277+ auto combined_high_val = combine_high (s1_lhs_vreg, s1_rhs_vreg);
5278+
5279+ // Determine the output indices for stage2_output_vregs.
5280+ // Each pair from Stage 1 produces a pair of vregs for Stage 2.
5281+ // Results are stored pair-wise:
5282+ // i=0 -> s2_vregs[0], s2_vregs[1]
5283+ // i=1 -> s2_vregs[2], s2_vregs[3]
5284+ // i=2 -> s2_vregs[4], s2_vregs[5]
5285+ // i=3 -> s2_vregs[6], s2_vregs[7]
5286+ int s2_out_idx_base = 2 * i;
5287+
5288+ stage2_output_vregs[s2_out_idx_base] =
5289+ shuffle (combined_low_val, combined_high_val,
5290+ permute_pattern_stage2_low_arr);
5291+ stage2_output_vregs[s2_out_idx_base + 1 ] =
5292+ shuffle (combined_low_val, combined_high_val,
5293+ permute_pattern_stage2_high_arr);
5294+ }
5295+
5296+ // STAGE 3! Combine results from stage 2.
5297+ std::array<int64_t , 3 > output_idx_parts{
5298+ second_minor_dim_slice_idx * sublane_count, major_dim_slice_idx,
5299+ minor_most_dim_slice_idx};
5300+
5301+ constexpr int num_final_combines =
5302+ 4 ; // Corresponds to s2_vregs[0]..s2_vregs[3] pairing with
5303+ // s2_vregs[4]..s2_vregs[7]
5304+ for (int i = 0 ; i < num_final_combines; ++i) {
5305+ Value lhs = stage2_output_vregs[i]; // e.g., s2_ABCD_0
5306+ Value rhs = stage2_output_vregs[i + 4 ]; // e.g., s2_EFGH_0
5307+ auto final_combined_low = combine_low (lhs, rhs);
5308+ auto final_combined_high = combine_high (lhs, rhs);
5309+
5310+ dst_vregs (output_idx_parts) = final_combined_low;
5311+ output_idx_parts[0 ] += 1 ;
5312+ dst_vregs (output_idx_parts) = final_combined_high;
5313+ output_idx_parts[0 ] += 1 ;
5314+ }
5315+ }
5316+ }
5317+ }
5318+ auto assembled =
5319+ assemble (builder, dst_ty, layout_out, dst_vregs, ctx.target_shape );
5320+ transpose_op.getOperation ()->replaceAllUsesWith (assembled);
5321+ transpose_op.erase ();
5322+ return success ();
50205323 }
5324+
50215325 {
50225326 SmallVector<int64_t > p (permutation);
50235327 p[rank - 2 ] = rank - 2 ;
0 commit comments