diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 505478b9ad72..a23d409dc9fe 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -470,20 +470,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, SameOperandsAndResultS let description = [{ Gathers elements from `source` using `indices`. - Given a shape `N0 x N1 x ...`, `output[i0, i1, ...]` is given by - `input[j0, j1, ...]` where `jn = indices[i0, i1, ...] mod Ni` for - `n = dimension` and `jn = in` otherwise. + The specified `dimensions` of `source` are collapsed together and indexed by + `indices`. - Similar to `np.take_along_axis`, except that OOB indices wrap. + Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by + `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where + - `collapsed_source` is the result of collapsing `dimensions` of `source` + into a new trailing dimension of size `M`. + - `j0, j1, ...` is the subsequence of `i0, i1, ...` over the dimensions `n` that are not in `dimensions`. + + When a single dimension is specified, this is similar to + `np.take_along_axis`, except that OOB indices wrap. 
}]; let arguments = (ins AnyVectorOfNonZeroRank:$source, VectorOfNonZeroRankOf<[AnyInteger]>:$indices, - I32Attr:$dimension + DenseI32ArrayAttr:$dimensions ); let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ - $source `[` $indices `]` `in` $dimension attr-dict + $source `[` $indices `]` `in` $dimensions attr-dict `:` type($source) `,` type($indices) `->` type($output) }]; let hasVerifier = 1; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 3733bf5d4465..b3801d56bdb0 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1449,8 +1449,10 @@ LogicalResult DynamicGatherOp::verify() { - if (getIndices().getType().getShape() != getIndices().getType().getShape()) { + if (getIndices().getType().getShape() != getOutput().getType().getShape()) { return emitOpError("Expected indices and result shapes must match"); } - if (!getIndices().getType().getElementType().isInteger(32)) { - return emitOpError("Not implemented: Only i32 indices supported"); + for (int32_t d : getDimensions()) { + if (d < 0 || d >= getSource().getType().getRank()) { + return emitOpError("Invalid dimension specified: ") << d; + } } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 6502a9c6682e..f00977af02e8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3147,10 +3147,15 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op, auto dy_gather_op = cast<tpu::DynamicGatherOp>(op); // TODO(jevinjiang): we need to think harder for general vector shape. 
- if (dy_gather_op.getType().getShape() != - ArrayRef<int64_t>(ctx.target_shape)) { + if (!(dy_gather_op.getType().getElementTypeBitWidth() == 32 && + dy_gather_op.getType().getShape() == + ArrayRef<int64_t>(ctx.target_shape)) && + !(dy_gather_op.getType().getElementTypeBitWidth() == 8 && + dy_gather_op.getType().getShape() == + ArrayRef<int64_t>{4 * ctx.target_shape[0], ctx.target_shape[1]})) { return op.emitOpError( - "Not implemented: DynamicGatherOp only supports 32-bit VREG shape"); + "Not implemented: DynamicGatherOp only supports 8- or 32-bit VREG " + "shape"); } if (src_layout != out_layout || idx_layout != out_layout) { @@ -3159,7 +3164,7 @@ "result"); } - if (!out_layout.hasNaturalTopology(ctx.target_shape)) { + if (!out_layout.hasNativeTiling(ctx.target_shape)) { return op.emitOpError( "Not implemented: unsupported layout for DynamicGatherOp"); } @@ -3177,11 +3182,54 @@ TPU_ASSERT_EQ_OP(src_vregs.dimensions(), idx_vregs.dimensions()); TPU_ASSERT_EQ_OP(src_vregs.num_elements(), 1); + Location loc = dy_gather_op.getLoc(); + SmallVector<int32_t> dimensions(dy_gather_op.getDimensions()); + if (dy_gather_op.getType().getElementTypeBitWidth() == 8) { + if (dy_gather_op.getDimensions() != ArrayRef<int32_t>{0}) { + return dy_gather_op.emitOpError( + "Not implemented: 8-bit dynamic gather only supported along " + "dimension 0"); + } + // Vreg shape is 8x128x4, and lowering only supports dimensions == {2, 0}, + // i.e. byte index is in the upper bits and sublane index in the lower bits. + // However, the input indices effectively have sublane index in the upper + // bits and byte index in the lower bits. 
+ VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); + VectorType i8_vreg_ty = + getNativeVregType(builder.getI8Type(), ctx.target_shape); + idx_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) { + const int sublane_bits = llvm::Log2_64(ctx.target_shape[0]); + const int byte_bits = 2; + // This check ensures the shifting below does not change anything across + // bytes for relevant (byte and sublane) bits. Lets us mask just once. + CHECK_LE(sublane_bits + byte_bits + std::max(byte_bits, sublane_bits), 8); + // Zero out the high bits that specify neither byte nor sublane index (they might + // not be zero since op semantics allow wrapping). + Value mask = getFullVector( + builder, loc, i8_vreg_ty, + builder.getI8IntegerAttr((1 << (byte_bits + sublane_bits)) - 1)); + *v = builder.create<arith::AndIOp>(loc, mask, *v); + *v = builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, *v); + Value shifted_byte = builder.create<arith::ShLIOp>( + loc, *v, + getFullVector(builder, loc, i32_vreg_ty, + builder.getI32IntegerAttr(sublane_bits))); + Value shifted_sublane = builder.create<arith::ShRUIOp>( + loc, *v, + getFullVector(builder, loc, i32_vreg_ty, + builder.getI32IntegerAttr(byte_bits))); + *v = builder.create<arith::OrIOp>(loc, shifted_byte, shifted_sublane); + *v = builder.create<tpu::BitcastVregOp>(loc, i8_vreg_ty, *v); + }); + dimensions = SmallVector<int32_t>{2, 0}; + } + xla::Array<Value> out_vregs(src_vregs.dimensions()); out_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) { - *v = builder.create<tpu::DynamicGatherOp>( - op.getLoc(), src_vregs(idxs).getType(), src_vregs(idxs), - idx_vregs(idxs), dy_gather_op.getDimension()); + *v = builder.create<tpu::DynamicGatherOp>(loc, src_vregs(idxs).getType(), + src_vregs(idxs), idx_vregs(idxs), + dimensions); }); dy_gather_op.replaceAllUsesWith( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 976e31cb55f4..3b82a786063b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -965,22 +965,23 @@ 
class VectorLayoutInferer { } LogicalResult infer(tpu::DynamicGatherOp op) { - if (op.getType().getShape() != ArrayRef<int64_t>(target_shape_) && - op.getType().getElementTypeBitWidth() != 32) { - return op.emitOpError( - "Not implemented: DynamicGatherOp only supports 32-bit VREG shape"); - } - if (op.getDimension() != 0 && op.getDimension() != 1) { - return op.emitOpError( - "Not implemented: Only dimension 0 and 1 are supported"); - } // TODO(jevinjiang): we could preserve some offsets such as replicated // offset but since we are forcing all operands and result to be the same // layout, we can set all offsets to zero for now. Also maybe we should // consider adding this to elementwise rule. - auto layout = VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, - ImplicitDim::kNone); - setLayout(op, {layout, layout}, layout); + if (op.getType().getShape() == ArrayRef<int64_t>(target_shape_) && + op.getType().getElementTypeBitWidth() == 32) { + VectorLayout layout(kNativeBitwidth, {0, 0}, default_tiling_, + ImplicitDim::kNone); + setLayout(op, {layout, layout}, layout); + } else if (op.getIndices().getType().getShape() == + ArrayRef<int64_t>{4 * target_shape_[0], target_shape_[1]} && + op.getType().getElementTypeBitWidth() == 8) { + VectorLayout layout(8, {0, 0}, nativeTiling(8), ImplicitDim::kNone); + setLayout(op, {layout, layout}, layout); + } else { + return op.emitOpError("Not implemented: unsupported shape or element type for DynamicGatherOp"); + } return success(); }