Skip to content

Commit 7f7510b

Browse files
tlongeri authored and
Google-ML-Automation committed
[Mosaic:TPU] Byte-granularity dynamic gathers
PiperOrigin-RevId: 762076385
1 parent: 1aaec81 · commit: 7f7510b

File tree

4 files changed

+84
-27
lines changed

4 files changed

+84
-27
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,20 +470,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, SameOperandsAndResultS
470470
let description = [{
471471
Gathers elements from `source` using `indices`.
472472

473-
Given a shape `N0 x N1 x ...`, `output[i0, i1, ...]` is given by
474-
`input[j0, j1, ...]` where `jn = indices[i0, i1, ...] mod Ni` for
475-
`n = dimension` and `jn = in` otherwise.
473+
The specified `dimensions` of `source` are collapsed together and indexed by
474+
`indices`.
476475

477-
Similar to `np.take_along_axis`, except that OOB indices wrap.
476+
Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by
477+
`collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where
478+
- `collapsed_source` is the result of collapsing `dimensions` of `source`
479+
into a new trailing dimension of size `M`.
480+
- `jk` is the subsequence of `in` for `n` not in `dimensions`.
481+
482+
When a single dimension is specified, this is similar to
483+
`np.take_along_axis`, except that OOB indices wrap.
478484
}];
479485
let arguments = (ins
480486
AnyVectorOfNonZeroRank:$source,
481487
VectorOfNonZeroRankOf<[AnyInteger]>:$indices,
482-
I32Attr:$dimension
488+
DenseI32ArrayAttr:$dimensions
483489
);
484490
let results = (outs AnyVectorOfNonZeroRank:$output);
485491
let assemblyFormat = [{
486-
$source `[` $indices `]` `in` $dimension attr-dict
492+
$source `[` $indices `]` `in` $dimensions attr-dict
487493
`:` type($source) `,` type($indices) `->` type($output)
488494
}];
489495
let hasVerifier = 1;

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,8 +1449,10 @@ LogicalResult DynamicGatherOp::verify() {
14491449
if (getIndices().getType().getShape() != getIndices().getType().getShape()) {
14501450
return emitOpError("Expected indices and result shapes must match");
14511451
}
1452-
if (!getIndices().getType().getElementType().isInteger(32)) {
1453-
return emitOpError("Not implemented: Only i32 indices supported");
1452+
for (int32_t d : getDimensions()) {
1453+
if (d < 0 || d >= getSource().getType().getRank()) {
1454+
return emitOpError("Invalid dimension specified: ") << d;
1455+
}
14541456
}
14551457
return success();
14561458
}

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,10 +3147,15 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
31473147
auto dy_gather_op = cast<tpu::DynamicGatherOp>(op);
31483148

31493149
// TODO(jevinjiang): we need to think harder for general vector shape.
3150-
if (dy_gather_op.getType().getShape() !=
3151-
ArrayRef<int64_t>(ctx.target_shape)) {
3150+
if (!(dy_gather_op.getType().getElementTypeBitWidth() == 32 &&
3151+
dy_gather_op.getType().getShape() ==
3152+
ArrayRef<int64_t>(ctx.target_shape)) &&
3153+
!(dy_gather_op.getType().getElementTypeBitWidth() == 8 &&
3154+
dy_gather_op.getType().getShape() ==
3155+
ArrayRef<int64_t>{4 * ctx.target_shape[0], ctx.target_shape[1]})) {
31523156
return op.emitOpError(
3153-
"Not implemented: DynamicGatherOp only supports 32-bit VREG shape");
3157+
"Not implemented: DynamicGatherOp only supports 8- or 32-bit VREG "
3158+
"shape");
31543159
}
31553160

31563161
if (src_layout != out_layout || idx_layout != out_layout) {
@@ -3159,7 +3164,7 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
31593164
"result");
31603165
}
31613166

3162-
if (!out_layout.hasNaturalTopology(ctx.target_shape)) {
3167+
if (!out_layout.hasNativeTiling(ctx.target_shape)) {
31633168
return op.emitOpError(
31643169
"Not implemented: unsupported layout for DynamicGatherOp");
31653170
}
@@ -3177,11 +3182,54 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
31773182
TPU_ASSERT_EQ_OP(src_vregs.dimensions(), idx_vregs.dimensions());
31783183
TPU_ASSERT_EQ_OP(src_vregs.num_elements(), 1);
31793184

3185+
Location loc = dy_gather_op.getLoc();
3186+
SmallVector<int32_t> dimensions(dy_gather_op.getDimensions());
3187+
if (dy_gather_op.getType().getElementTypeBitWidth() == 8) {
3188+
if (dy_gather_op.getDimensions() != ArrayRef<int32_t>{0}) {
3189+
return dy_gather_op.emitOpError(
3190+
"Not implemented: 8-bit dynamic gather only supported along "
3191+
"dimension 0");
3192+
}
3193+
// Vreg shape is 8x128x4, and lowering only supports dimensions == {2, 0},
3194+
// i.e. byte index is in the upper bits and sublane index in the lower bits.
3195+
// However, the input indices effectively have sublane index in the upper
3196+
// bits and byte index in the lower bits.
3197+
VectorType i32_vreg_ty =
3198+
getNativeVregType(builder.getI32Type(), ctx.target_shape);
3199+
VectorType i8_vreg_ty =
3200+
getNativeVregType(builder.getI8Type(), ctx.target_shape);
3201+
idx_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
3202+
const int sublane_bits = llvm::Log2_64(ctx.target_shape[0]);
3203+
const int byte_bits = 2;
3204+
// This check ensures the shifting below does not change anything across
3205+
// bytes for relevant (byte and sublane) bits. Lets us mask just once.
3206+
CHECK_LE(sublane_bits + byte_bits + std::max(byte_bits, sublane_bits), 8);
3207+
// Zero out the high bits that specify neither byte nor index (they might
3208+
// not be zero since op semantics allow wrapping).
3209+
Value mask = getFullVector(
3210+
builder, loc, i8_vreg_ty,
3211+
builder.getI8IntegerAttr((1 << (byte_bits + sublane_bits)) - 1));
3212+
*v = builder.create<arith::AndIOp>(loc, mask, *v);
3213+
*v = builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, *v);
3214+
Value shifted_byte = builder.create<arith::ShLIOp>(
3215+
loc, *v,
3216+
getFullVector(builder, loc, i32_vreg_ty,
3217+
builder.getI32IntegerAttr(sublane_bits)));
3218+
Value shifted_sublane = builder.create<arith::ShRUIOp>(
3219+
loc, *v,
3220+
getFullVector(builder, loc, i32_vreg_ty,
3221+
builder.getI32IntegerAttr(byte_bits)));
3222+
*v = builder.create<arith::OrIOp>(loc, shifted_byte, shifted_sublane);
3223+
*v = builder.create<tpu::BitcastVregOp>(loc, i8_vreg_ty, *v);
3224+
});
3225+
dimensions = SmallVector<int32_t>{2, 0};
3226+
}
3227+
31803228
xla::Array<Value> out_vregs(src_vregs.dimensions());
31813229
out_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
3182-
*v = builder.create<tpu::DynamicGatherOp>(
3183-
op.getLoc(), src_vregs(idxs).getType(), src_vregs(idxs),
3184-
idx_vregs(idxs), dy_gather_op.getDimension());
3230+
*v = builder.create<tpu::DynamicGatherOp>(loc, src_vregs(idxs).getType(),
3231+
src_vregs(idxs), idx_vregs(idxs),
3232+
dimensions);
31853233
});
31863234

31873235
dy_gather_op.replaceAllUsesWith(

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -965,22 +965,23 @@ class VectorLayoutInferer {
965965
}
966966

967967
LogicalResult infer(tpu::DynamicGatherOp op) {
968-
if (op.getType().getShape() != ArrayRef<int64_t>(target_shape_) &&
969-
op.getType().getElementTypeBitWidth() != 32) {
970-
return op.emitOpError(
971-
"Not implemented: DynamicGatherOp only supports 32-bit VREG shape");
972-
}
973-
if (op.getDimension() != 0 && op.getDimension() != 1) {
974-
return op.emitOpError(
975-
"Not implemented: Only dimension 0 and 1 are supported");
976-
}
977968
// TODO(jevinjiang): we could preserve some offsets such as replicated
978969
// offset but since we are forcing all operands and result to be the same
979970
// layout, we can set all offsets to zero for now. Also maybe we should
980971
// consider adding this to elementwise rule.
981-
auto layout = VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,
982-
ImplicitDim::kNone);
983-
setLayout(op, {layout, layout}, layout);
972+
if (op.getType().getShape() == ArrayRef<int64_t>(target_shape_) &&
973+
op.getType().getElementTypeBitWidth() == 32) {
974+
VectorLayout layout(kNativeBitwidth, {0, 0}, default_tiling_,
975+
ImplicitDim::kNone);
976+
setLayout(op, {layout, layout}, layout);
977+
} else if (op.getIndices().getType().getShape() ==
978+
ArrayRef<int64_t>{4 * target_shape_[0], target_shape_[1]} &&
979+
op.getType().getElementTypeBitWidth() == 8) {
980+
VectorLayout layout(8, {0, 0}, nativeTiling(8), ImplicitDim::kNone);
981+
setLayout(op, {layout, layout}, layout);
982+
} else {
983+
return op.emitOpError("Not implemented");
984+
}
984985
return success();
985986
}
986987

0 commit comments

Comments
 (0)