Skip to content

Commit 62c46ff

Browse files
Add initial support for aligned jnp.swapaxes on major/minor dims
Next steps: - non-tile aligned - Clean up fn and utilize it for general changeTiling PiperOrigin-RevId: 761731600
1 parent 8da86ea commit 62c46ff

4 files changed

Lines changed: 375 additions & 13 deletions

File tree

jax/_src/pallas/mosaic/lowering.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2383,7 +2383,9 @@ def _gather_lowering_rule(
23832383

23842384
@register_lowering_rule(lax.transpose_p)
23852385
def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
2386-
if permutation != (1, 0):
2386+
minormost_transpose = (1, 0)
2387+
untiled_tiled_swap = (1, 0, 2)
2388+
if permutation not in (minormost_transpose, untiled_tiled_swap):
23872389
raise NotImplementedError
23882390
out_type = aval_to_ir_type(
23892391
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 306 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5013,11 +5013,315 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
50135013
ctx.target_shape));
50145014
ArrayRef<int64_t> permutation = transpose_op.getPermutation();
50155015
const auto tile_perm = permutation.take_back(2);
5016+
5017+
// Major/minor permute
50165018
if (tile_perm != ArrayRef<int64_t>{rank - 2, rank - 1} &&
50175019
tile_perm != ArrayRef<int64_t>{rank - 1, rank - 2}) {
5018-
return transpose_op->emitOpError(
5019-
"Not implemented: Unsupported permutation");
5020+
// This is a 3 stage algorithm that uses combinations and shuffles
5021+
// to do a transposition of an 8x8 block of sublanes.
5022+
// In the following algorithm description, A, B, ..., H represent 8
5023+
// distinct input vregs that form an 8x8 block of data
5024+
// to be transposed. In our notation, B2 identifies the third
5025+
// sublane (2) of the second vreg (B).
5026+
//
5027+
//
5028+
// If we think of each starting input vreg as a row in an 8x8 block of
5029+
// elements:
5030+
// A: A0 A1 A2 A3 A4 A5 A6 A7
5031+
// B: B0 B1 B2 B3 B4 B5 B6 B7
5032+
// ...
5033+
// H: H0 H1 H2 H3 H4 H5 H6 H7
5034+
//
5035+
// The goal is to transpose this block, so the output vregs are:
5036+
// out0: A0 B0 C0 D0 E0 F0 G0 H0
5037+
// out1: A1 B1 C1 D1 E1 F1 G1 H1
5038+
// ...
5039+
// out7: A7 B7 C7 D7 E7 F7 G7 H7
5040+
//
5041+
// Stage 1: Operates on pairs of input vregs (e.g., A and B).
5042+
//
5043+
// Input to Stage 1 (example pair A, B):
5044+
// A: A0 A1 A2 A3 A4 A5 A6 A7
5045+
// B: B0 B1 B2 B3 B4 B5 B6 B7
5046+
//
5047+
// Step 1.1: Combine low/high halves.
5048+
// combine_low(A, B) -> CL_AB: [A0 A1 A2 A3 | B0 B1 B2 B3] (8 elements)
5049+
// combine_high(A, B) -> CH_AB: [A4 A5 A6 A7 | B4 B5 B6 B7] (8 elements)
5050+
// (Notation: '|' separates the 4 elements from A and 4 from B)
5051+
//
5052+
// Step 1.2: Shuffle.
5053+
// The shuffle pattern for the low part (applied to CL_AB using
5054+
// `shuffle(CL_AB, CH_AB, pattern)`) is {0, 4, 1, 5, 2, 6, 3, 7}.
5055+
// The shuffle pattern for the high part (applied to CH_AB using
5056+
// `shuffle(CL_AB, CH_AB, pattern)`) is {8, 12, 9, 13, 10, 14, 11, 15}.
5057+
// (Indices 0-7 in shuffle refer to CL_AB, 8-15 to CH_AB).
5058+
// This results in:
5059+
// s1_AB_0: A0 B0 A1 B1 A2 B2 A3 B3 (from shuffling CL_AB elements)
5060+
// s1_AB_1: A4 B4 A5 B5 A6 B6 A7 B7 (from shuffling CH_AB elements)
5061+
//
5062+
// Output of Stage 1 / Input to Stage 2 (example for A,B,C,D processing):
5063+
// s1_vregs[0] (from A,B): A0 B0 A1 B1 A2 B2 A3 B3
5064+
// s1_vregs[1] (from A,B): A4 B4 A5 B5 A6 B6 A7 B7
5065+
// s1_vregs[2] (from C,D): C0 D0 C1 D1 C2 D2 C3 D3
5066+
// s1_vregs[3] (from C,D): C4 D4 C5 D5 C6 D6 C7 D7
5067+
// ... (and so on for E,F,G,H into s1_vregs[4-7])
5068+
5069+
// Stage 2: Operates on groups of 4 vregs from Stage 1 output.
5070+
// (e.g., s1_vregs[0], s1_vregs[1], s1_vregs[2], s1_vregs[3])
5071+
//
5072+
// Input to Stage 2 (example processing s1_vregs[0] and s1_vregs[2]):
5073+
// X = s1_vregs[0] = [A0 B0 A1 B1 | A2 B2 A3 B3]
5074+
// Y = s1_vregs[2] = [C0 D0 C1 D1 | C2 D2 C3 D3]
5075+
//
5076+
// Step 2.1: Combine low/high halves.
5077+
// combine_low(X, Y) -> CL_XY: [A0 B0 A1 B1 | C0 D0 C1 D1]
5078+
// combine_high(X, Y) -> CH_XY: [A2 B2 A3 B3 | C2 D2 C3 D3]
5079+
//
5080+
// (Similarly for s1_vregs[1] and s1_vregs[3], let them be X' and Y')
5081+
// combine_low(X', Y') -> CL_X'Y': [A4 B4 A5 B5 | C4 D4 C5 D5]
5082+
// combine_high(X', Y') -> CH_X'Y': [A6 B6 A7 B7 | C6 D6 C7 D7]
5083+
//
5084+
// Step 2.2: Shuffle.
5085+
// The shuffle pattern for the low part (e.g., applied to CL_XY) is {0, 1,
5086+
// 4, 5, 2, 3, 6, 7}. The shuffle pattern for the high part (e.g., applied
5087+
// to CH_XY, effectively) is {8, 9, 12, 13, 10, 11, 14, 15}.
5088+
//
5089+
// This results in (for the first group of 4 input vregs A,B,C,D):
5090+
// s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1 (from shuffling CL_XY elements)
5091+
// s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3 (from shuffling CH_XY elements)
5092+
// s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5 (from shuffling CL_X'Y' elements)
5093+
// s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7 (from shuffling CH_X'Y' elements)
5094+
//
5095+
// Output of Stage 2 / Input to Stage 3:
5096+
// s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1
5097+
// s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3
5098+
// s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5
5099+
// s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7
5100+
// s2_vregs[4]: E0 F0 G0 H0 E1 F1 G1 H1 (from E,F,G,H processing)
5101+
// s2_vregs[5]: E2 F2 G2 H2 E3 F3 G3 H3
5102+
// s2_vregs[6]: E4 F4 G4 H4 E5 F5 G5 H5
5103+
// s2_vregs[7]: E6 F6 G6 H6 E7 F7 G7 H7
5104+
5105+
// Stage 3: Combine results from Stage 2. No shuffle needed after combine.
5106+
// Input to Stage 3 (example for the first two rows of the final transpose):
5107+
// L = s2_vregs[0] = [A0 B0 C0 D0 | A1 B1 C1 D1]
5108+
// R = s2_vregs[4] = [E0 F0 G0 H0 | E1 F1 G1 H1]
5109+
//
5110+
// Step 3.1: Combine low/high halves.
5111+
// combine_low(L, R) -> [A0 B0 C0 D0 | E0 F0 G0 H0] ->
5112+
// Final out0: A0 B0 C0 D0 E0 F0 G0 H0
5113+
// combine_high(L, R) -> [A1 B1 C1 D1 | E1 F1 G1 H1] ->
5114+
// Final out1: A1 B1 C1 D1 E1 F1 G1 H1
5115+
// ... and so on for other pairs from Stage 2 output
5116+
// (e.g. L=s2_vregs[1], R=s2_vregs[5]).
5117+
//
5118+
// This results in the correctly transposed 8x8 block.
5119+
5120+
constexpr int64_t kMajorDimOriginalIdx = 0;
5121+
constexpr int64_t kSecondMinorDimOriginalIdx = 1;
5122+
constexpr int64_t kMinorMostDimOriginalIdx = 2;
5123+
5124+
auto vec_shape = src_ty.getShape();
5125+
auto major_dim_size = vec_shape[kMajorDimOriginalIdx];
5126+
auto second_minor_dim_size = vec_shape[kSecondMinorDimOriginalIdx];
5127+
5128+
if (layout_in.offsets() != LayoutOffsets{0, 0}) {
5129+
return transpose_op.emitOpError("Not implemented: Layout with offset.");
5130+
}
5131+
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
5132+
return transpose_op.emitOpError(
5133+
"Not implemented: Layout with implicit dimension.");
5134+
}
5135+
5136+
auto sublane_count = ctx.target_shape[0];
5137+
if (second_minor_dim_size % sublane_count != 0 ||
5138+
major_dim_size % sublane_count != 0) {
5139+
return transpose_op.emitOpError(
5140+
"Not implemented: Swapping major and second minor dimensions must "
5141+
"result in dimension sizes that are multiples of sublane_count.");
5142+
}
5143+
5144+
if (!layout_in.hasNativeTiling(ctx.target_shape)) {
5145+
return transpose_op.emitOpError(
5146+
"Not implemented: Expected native input tiling.");
5147+
}
5148+
if (layout_in != layout_out) {
5149+
return transpose_op.emitOpError(
5150+
"Not implemented: Expected same input and output layouts.");
5151+
}
5152+
xla::Array<Value> dst_vregs(
5153+
layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape));
5154+
5155+
if (layout_in.bitwidth() != 32) {
5156+
return transpose_op.emitOpError(
5157+
"Not implemented: Major-second-minor transpose only supported for "
5158+
"32-bit vectors. Also, input must be a vector type.");
5159+
}
5160+
if (ctx.target_shape[0] != 8) {
5161+
return transpose_op.emitOpError(
5162+
"Not implemented: Major-second-minor transpose expects 8 sublanes.");
5163+
}
5164+
5165+
auto vreg_dimensions = src_vregs.dimensions();
5166+
// Note(mvoz): Slice is a weird word here. This is used for constructing
5167+
// the output vregs - the reason we divide here is because we multiply it
5168+
// back later on to get the correct index into src_vregs, but the reason
5169+
// we cannot just resolve that in our outer loop is because of the nature
5170+
// of a transpose - this dim value goes unmultiplied into the output vregs.
5171+
// Effectively, our indexing:
5172+
// {major_dim_slice_idx * sublane_count, second_minor_dim_slice_idx,
5173+
// minor_most_dim_slice_idx} becomes {second_minor_dim_slice_idx *
5174+
// sublane_count, major_dim_slice_idx, minor_most_dim_slice_idx}
5175+
auto num_slices_in_major_dim =
5176+
vreg_dimensions[kMajorDimOriginalIdx] / sublane_count;
5177+
auto num_slices_in_second_minor_dim =
5178+
vreg_dimensions[kSecondMinorDimOriginalIdx];
5179+
auto num_slices_in_minor_most_dim =
5180+
vreg_dimensions[kMinorMostDimOriginalIdx];
5181+
5182+
auto shuffle = [&](Value lhs_vreg, Value rhs_vreg, ArrayRef<int> pattern) {
5183+
auto lhs_vreg_type = lhs_vreg.getType();
5184+
auto pattern_attr = builder.getDenseI32ArrayAttr(pattern);
5185+
return builder
5186+
.create<tpu::SublaneShuffleOp>(transpose_op.getLoc(), lhs_vreg_type,
5187+
lhs_vreg, rhs_vreg, pattern_attr)
5188+
.getResult();
5189+
};
5190+
5191+
static constexpr std::array<int, 8> combine_low_pattern = {0, 1, 2, 3,
5192+
8, 9, 10, 11};
5193+
static constexpr std::array<int, 8> combine_high_pattern = {4, 5, 6, 7,
5194+
12, 13, 14, 15};
5195+
5196+
auto combine_low = [&](Value lhs_vreg, Value rhs_vreg) {
5197+
return shuffle(lhs_vreg, rhs_vreg, combine_low_pattern);
5198+
};
5199+
auto combine_high = [&](Value lhs_vreg, Value rhs_vreg) {
5200+
return shuffle(lhs_vreg, rhs_vreg, combine_high_pattern);
5201+
};
5202+
5203+
// Shuffle patterns for Stage 1
5204+
// Input to shuffle: (combine_low_val, combine_high_val)
5205+
// combine_low_val has A0-A3, B0-B3. Indices 0-7 for shuffle.
5206+
// combine_high_val has A4-A7, B4-B7. Indices 8-15 for shuffle.
5207+
static constexpr std::array<int, 8> permute_pattern_stage1_low_arr = {
5208+
0, 4, 1, 5,
5209+
2, 6, 3, 7}; // Selects from combine_low_val to make A0B0A1B1A2B2A3B3
5210+
static constexpr std::array<int, 8> permute_pattern_stage1_high_arr = {
5211+
8, 12, 9, 13, 10,
5212+
14, 11, 15}; // Selects from combine_high_val to make A4B4A5B5A6B6A7B7
5213+
5214+
// Shuffle patterns for Stage 2
5215+
// Input to shuffle: (CL_XY, CH_XY) from Step 2.1 in comments.
5216+
// CL_XY has A0B0A1B1C0D0C1D1. Indices 0-7 for shuffle.
5217+
// CH_XY has A2B2A3B3C2D2C3D3. Indices 8-15 for shuffle.
5218+
static constexpr std::array<int, 8> permute_pattern_stage2_low_arr = {
5219+
0, 1, 4, 5, 2, 3, 6, 7}; // Selects from CL_XY to make A0B0C0D0A1B1C1D1
5220+
static constexpr std::array<int, 8> permute_pattern_stage2_high_arr = {
5221+
8, 9, 12, 13,
5222+
10, 11, 14, 15}; // Selects from CH_XY to make A2B2C2D2A3B3C3D3
5223+
5224+
for (int major_dim_slice_idx = 0;
5225+
major_dim_slice_idx < num_slices_in_major_dim; ++major_dim_slice_idx) {
5226+
for (int second_minor_dim_slice_idx = 0;
5227+
second_minor_dim_slice_idx < num_slices_in_second_minor_dim;
5228+
++second_minor_dim_slice_idx) {
5229+
for (int minor_most_dim_slice_idx = 0;
5230+
minor_most_dim_slice_idx < num_slices_in_minor_most_dim;
5231+
++minor_most_dim_slice_idx) {
5232+
// STAGE 1!
5233+
std::array<Value, 8>
5234+
stage1_output_vregs; // Stores s1_vregs from comments
5235+
constexpr int num_pairs_stage1 =
5236+
4; // Processes 4 pairs of vregs (A,B), (C,D), (E,F), (G,H)
5237+
5238+
for (int i = 0; i < num_pairs_stage1; ++i) {
5239+
Value first_vreg = src_vregs(
5240+
{(2 * i) + (sublane_count * major_dim_slice_idx),
5241+
second_minor_dim_slice_idx, minor_most_dim_slice_idx});
5242+
Value second_vreg = src_vregs(
5243+
{(2 * i) + (sublane_count * major_dim_slice_idx) + 1,
5244+
second_minor_dim_slice_idx, minor_most_dim_slice_idx});
5245+
5246+
auto combined_low_val = combine_low(first_vreg, second_vreg);
5247+
auto combined_high_val = combine_high(first_vreg, second_vreg);
5248+
5249+
stage1_output_vregs[2 * i] =
5250+
shuffle(combined_low_val, combined_high_val,
5251+
permute_pattern_stage1_low_arr);
5252+
stage1_output_vregs[2 * i + 1] =
5253+
shuffle(combined_low_val, combined_high_val,
5254+
permute_pattern_stage1_high_arr);
5255+
}
5256+
5257+
// STAGE 2!
5258+
std::array<Value, 8>
5259+
stage2_output_vregs; // Stores s2_vregs from comments
5260+
constexpr int num_pairs_stage2 =
5261+
4; // Processes 4 pairs of vregs from stage1_output_vregs
5262+
5263+
for (int i = 0; i < num_pairs_stage2; ++i) {
5264+
// Determine the indices for the input pair from
5265+
// stage1_output_vregs. The 4 pairs processed in this stage are:
5266+
// i=0: (s1_vregs[0], s1_vregs[2])
5267+
// i=1: (s1_vregs[1], s1_vregs[3])
5268+
// i=2: (s1_vregs[4], s1_vregs[6])
5269+
// i=3: (s1_vregs[5], s1_vregs[7])
5270+
int s1_lhs_idx = (i / 2) * 4 + (i % 2);
5271+
int s1_rhs_idx = s1_lhs_idx + 2;
5272+
5273+
Value s1_lhs_vreg = stage1_output_vregs[s1_lhs_idx];
5274+
Value s1_rhs_vreg = stage1_output_vregs[s1_rhs_idx];
5275+
5276+
auto combined_low_val = combine_low(s1_lhs_vreg, s1_rhs_vreg);
5277+
auto combined_high_val = combine_high(s1_lhs_vreg, s1_rhs_vreg);
5278+
5279+
// Determine the output indices for stage2_output_vregs.
5280+
// Each pair from Stage 1 produces a pair of vregs for Stage 2.
5281+
// Results are stored pair-wise:
5282+
// i=0 -> s2_vregs[0], s2_vregs[1]
5283+
// i=1 -> s2_vregs[2], s2_vregs[3]
5284+
// i=2 -> s2_vregs[4], s2_vregs[5]
5285+
// i=3 -> s2_vregs[6], s2_vregs[7]
5286+
int s2_out_idx_base = 2 * i;
5287+
5288+
stage2_output_vregs[s2_out_idx_base] =
5289+
shuffle(combined_low_val, combined_high_val,
5290+
permute_pattern_stage2_low_arr);
5291+
stage2_output_vregs[s2_out_idx_base + 1] =
5292+
shuffle(combined_low_val, combined_high_val,
5293+
permute_pattern_stage2_high_arr);
5294+
}
5295+
5296+
// STAGE 3! Combine results from stage 2.
5297+
std::array<int64_t, 3> output_idx_parts{
5298+
second_minor_dim_slice_idx * sublane_count, major_dim_slice_idx,
5299+
minor_most_dim_slice_idx};
5300+
5301+
constexpr int num_final_combines =
5302+
4; // Corresponds to s2_vregs[0]..s2_vregs[3] pairing with
5303+
// s2_vregs[4]..s2_vregs[7]
5304+
for (int i = 0; i < num_final_combines; ++i) {
5305+
Value lhs = stage2_output_vregs[i]; // e.g., s2_ABCD_0
5306+
Value rhs = stage2_output_vregs[i + 4]; // e.g., s2_EFGH_0
5307+
auto final_combined_low = combine_low(lhs, rhs);
5308+
auto final_combined_high = combine_high(lhs, rhs);
5309+
5310+
dst_vregs(output_idx_parts) = final_combined_low;
5311+
output_idx_parts[0] += 1;
5312+
dst_vregs(output_idx_parts) = final_combined_high;
5313+
output_idx_parts[0] += 1;
5314+
}
5315+
}
5316+
}
5317+
}
5318+
auto assembled =
5319+
assemble(builder, dst_ty, layout_out, dst_vregs, ctx.target_shape);
5320+
transpose_op.getOperation()->replaceAllUsesWith(assembled);
5321+
transpose_op.erase();
5322+
return success();
50205323
}
5324+
50215325
{
50225326
SmallVector<int64_t> p(permutation);
50235327
p[rank - 2] = rank - 2;

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,17 +1680,27 @@ class VectorLayoutInferer {
16801680
auto src_ty = op.getSourceVectorType();
16811681
TPU_CHECK_OP(permutation.size() == src_ty.getRank(),
16821682
"Transpose permutation has incorrect rank");
1683-
for (auto dim : permutation.drop_back(2)) {
1684-
TPU_CHECK_OP(dim < src_ty.getRank() - 2,
1685-
"Unsupported transpose permutation - minor dims into major");
1686-
}
1687-
for (auto dim : permutation.take_back(2)) {
1688-
TPU_CHECK_OP(dim >= src_ty.getRank() - 2,
1689-
"Unsupported transpose permutation - major dims into minor");
1683+
bool untiled_tiled_swap = false;
1684+
// TODO(mvoz): Expand to more general cases. b/419268277
1685+
if (permutation.size() == 3 && permutation[0] == 1 && permutation[1] == 0) {
1686+
untiled_tiled_swap = true;
1687+
} else {
1688+
for (auto dim : permutation.drop_back(2)) {
1689+
TPU_CHECK_OP(dim < src_ty.getRank() - 2,
1690+
"Unsupported transpose permutation - minor dims into "
1691+
"major > 3 dimensions");
1692+
}
1693+
for (auto dim : permutation.take_back(2)) {
1694+
TPU_CHECK_OP(dim >= src_ty.getRank() - 2,
1695+
"Unsupported transpose permutation - major dims into "
1696+
"minor > 3 dimensions");
1697+
}
16901698
}
16911699
Layout required_layout = some_layout;
1692-
// Require native tiling if we're going to use the XLU.
1693-
if (permutation[permutation.size() - 1] == permutation.size() - 2) {
1700+
// Require native tiling if we're going to use the XLU, or doing a
1701+
// major/minor permute.
1702+
if (untiled_tiled_swap ||
1703+
permutation[permutation.size() - 1] == permutation.size() - 2) {
16941704
auto native_tiling = nativeTiling(layout.bitwidth());
16951705
required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0},
16961706
native_tiling, ImplicitDim::kNone);

0 commit comments

Comments
 (0)