Skip to content

[Mosaic:TPU] Byte-granularity dynamic gathers #28952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -470,20 +470,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, SameOperandsAndResultS
let description = [{
Gathers elements from `source` using `indices`.

Given a shape `N0 x N1 x ...`, `output[i0, i1, ...]` is given by
`input[j0, j1, ...]` where `jn = indices[i0, i1, ...] mod Ni` for
`n = dimension` and `jn = in` otherwise.
The specified `dimensions` of `source` are collapsed together and indexed by
`indices`.

Similar to `np.take_along_axis`, except that OOB indices wrap.
Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by
`collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where
- `collapsed_source` is the result of collapsing `dimensions` of `source`
into a new trailing dimension of size `M`.
- `jk` is the subsequence of `in` for `n` not in `dimensions`.

When a single dimension is specified, this is similar to
`np.take_along_axis`, except that OOB indices wrap.
}];
let arguments = (ins
AnyVectorOfNonZeroRank:$source,
VectorOfNonZeroRankOf<[AnyInteger]>:$indices,
I32Attr:$dimension
DenseI32ArrayAttr:$dimensions
);
let results = (outs AnyVectorOfNonZeroRank:$output);
let assemblyFormat = [{
$source `[` $indices `]` `in` $dimension attr-dict
$source `[` $indices `]` `in` $dimensions attr-dict
`:` type($source) `,` type($indices) `->` type($output)
}];
let hasVerifier = 1;
Expand Down
6 changes: 4 additions & 2 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1449,8 +1449,10 @@ LogicalResult DynamicGatherOp::verify() {
if (getIndices().getType().getShape() != getIndices().getType().getShape()) {
return emitOpError("Expected indices and result shapes must match");
}
if (!getIndices().getType().getElementType().isInteger(32)) {
return emitOpError("Not implemented: Only i32 indices supported");
for (int32_t d : getDimensions()) {
if (d < 0 || d >= getSource().getType().getRank()) {
return emitOpError("Invalid dimension specified: ") << d;
}
}
return success();
}
Expand Down
62 changes: 55 additions & 7 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3147,10 +3147,15 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
auto dy_gather_op = cast<tpu::DynamicGatherOp>(op);

// TODO(jevinjiang): we need to think harder for general vector shape.
if (dy_gather_op.getType().getShape() !=
ArrayRef<int64_t>(ctx.target_shape)) {
if (!(dy_gather_op.getType().getElementTypeBitWidth() == 32 &&
dy_gather_op.getType().getShape() ==
ArrayRef<int64_t>(ctx.target_shape)) &&
!(dy_gather_op.getType().getElementTypeBitWidth() == 8 &&
dy_gather_op.getType().getShape() ==
ArrayRef<int64_t>{4 * ctx.target_shape[0], ctx.target_shape[1]})) {
return op.emitOpError(
"Not implemented: DynamicGatherOp only supports 32-bit VREG shape");
"Not implemented: DynamicGatherOp only supports 8- or 32-bit VREG "
"shape");
}

if (src_layout != out_layout || idx_layout != out_layout) {
Expand All @@ -3159,7 +3164,7 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
"result");
}

if (!out_layout.hasNaturalTopology(ctx.target_shape)) {
if (!out_layout.hasNativeTiling(ctx.target_shape)) {
return op.emitOpError(
"Not implemented: unsupported layout for DynamicGatherOp");
}
Expand All @@ -3177,11 +3182,54 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(src_vregs.dimensions(), idx_vregs.dimensions());
TPU_ASSERT_EQ_OP(src_vregs.num_elements(), 1);

Location loc = dy_gather_op.getLoc();
SmallVector<int32_t> dimensions(dy_gather_op.getDimensions());
if (dy_gather_op.getType().getElementTypeBitWidth() == 8) {
if (dy_gather_op.getDimensions() != ArrayRef<int32_t>{0}) {
return dy_gather_op.emitOpError(
"Not implemented: 8-bit dynamic gather only supported along "
"dimension 0");
}
// Vreg shape is 8x128x4, and lowering only supports dimensions == {2, 0},
// i.e. byte index is in the upper bits and sublane index in the lower bits.
// However, the input indices effectively have sublane index in the upper
// bits and byte index in the lower bits.
VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), ctx.target_shape);
VectorType i8_vreg_ty =
getNativeVregType(builder.getI8Type(), ctx.target_shape);
idx_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
const int sublane_bits = llvm::Log2_64(ctx.target_shape[0]);
const int byte_bits = 2;
// This check ensures the shifting below does not change anything across
// bytes for relevant (byte and sublane) bits. Lets us mask just once.
CHECK_LE(sublane_bits + byte_bits + std::max(byte_bits, sublane_bits), 8);
// Zero out the high bits that specify neither byte nor index (they might
// not be zero since op semantics allow wrapping).
Value mask = getFullVector(
builder, loc, i8_vreg_ty,
builder.getI8IntegerAttr((1 << (byte_bits + sublane_bits)) - 1));
*v = builder.create<arith::AndIOp>(loc, mask, *v);
*v = builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, *v);
Value shifted_byte = builder.create<arith::ShLIOp>(
loc, *v,
getFullVector(builder, loc, i32_vreg_ty,
builder.getI32IntegerAttr(sublane_bits)));
Value shifted_sublane = builder.create<arith::ShRUIOp>(
loc, *v,
getFullVector(builder, loc, i32_vreg_ty,
builder.getI32IntegerAttr(byte_bits)));
*v = builder.create<arith::OrIOp>(loc, shifted_byte, shifted_sublane);
*v = builder.create<tpu::BitcastVregOp>(loc, i8_vreg_ty, *v);
});
dimensions = SmallVector<int32_t>{2, 0};
}

xla::Array<Value> out_vregs(src_vregs.dimensions());
out_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<tpu::DynamicGatherOp>(
op.getLoc(), src_vregs(idxs).getType(), src_vregs(idxs),
idx_vregs(idxs), dy_gather_op.getDimension());
*v = builder.create<tpu::DynamicGatherOp>(loc, src_vregs(idxs).getType(),
src_vregs(idxs), idx_vregs(idxs),
dimensions);
});

dy_gather_op.replaceAllUsesWith(
Expand Down
25 changes: 13 additions & 12 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -965,22 +965,23 @@ class VectorLayoutInferer {
}

LogicalResult infer(tpu::DynamicGatherOp op) {
if (op.getType().getShape() != ArrayRef<int64_t>(target_shape_) &&
op.getType().getElementTypeBitWidth() != 32) {
return op.emitOpError(
"Not implemented: DynamicGatherOp only supports 32-bit VREG shape");
}
if (op.getDimension() != 0 && op.getDimension() != 1) {
return op.emitOpError(
"Not implemented: Only dimension 0 and 1 are supported");
}
// TODO(jevinjiang): we could preserve some offsets such as replicated
// offset but since we are forcing all operands and result to be the same
// layout, we can set all offsets to zero for now. Also maybe we should
// consider adding this to elementwise rule.
auto layout = VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,
ImplicitDim::kNone);
setLayout(op, {layout, layout}, layout);
if (op.getType().getShape() == ArrayRef<int64_t>(target_shape_) &&
op.getType().getElementTypeBitWidth() == 32) {
VectorLayout layout(kNativeBitwidth, {0, 0}, default_tiling_,
ImplicitDim::kNone);
setLayout(op, {layout, layout}, layout);
} else if (op.getIndices().getType().getShape() ==
ArrayRef<int64_t>{4 * target_shape_[0], target_shape_[1]} &&
op.getType().getElementTypeBitWidth() == 8) {
VectorLayout layout(8, {0, 0}, nativeTiling(8), ImplicitDim::kNone);
setLayout(op, {layout, layout}, layout);
} else {
return op.emitOpError("Not implemented");
}
return success();
}

Expand Down
Loading