@@ -3147,10 +3147,15 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
3147
3147
auto dy_gather_op = cast<tpu::DynamicGatherOp>(op);
3148
3148
3149
3149
// TODO(jevinjiang): we need to think harder for general vector shape.
3150
- if (dy_gather_op.getType ().getShape () !=
3151
- ArrayRef<int64_t >(ctx.target_shape )) {
3150
+ if (!(dy_gather_op.getType ().getElementTypeBitWidth () == 32 &&
3151
+ dy_gather_op.getType ().getShape () ==
3152
+ ArrayRef<int64_t >(ctx.target_shape )) &&
3153
+ !(dy_gather_op.getType ().getElementTypeBitWidth () == 8 &&
3154
+ dy_gather_op.getType ().getShape () ==
3155
+ ArrayRef<int64_t >{4 * ctx.target_shape [0 ], ctx.target_shape [1 ]})) {
3152
3156
return op.emitOpError (
3153
- " Not implemented: DynamicGatherOp only supports 32-bit VREG shape" );
3157
+ " Not implemented: DynamicGatherOp only supports 8- or 32-bit VREG "
3158
+ " shape" );
3154
3159
}
3155
3160
3156
3161
if (src_layout != out_layout || idx_layout != out_layout) {
@@ -3159,7 +3164,7 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
3159
3164
" result" );
3160
3165
}
3161
3166
3162
- if (!out_layout.hasNaturalTopology (ctx.target_shape )) {
3167
+ if (!out_layout.hasNativeTiling (ctx.target_shape )) {
3163
3168
return op.emitOpError (
3164
3169
" Not implemented: unsupported layout for DynamicGatherOp" );
3165
3170
}
@@ -3177,11 +3182,54 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op,
3177
3182
TPU_ASSERT_EQ_OP (src_vregs.dimensions (), idx_vregs.dimensions ());
3178
3183
TPU_ASSERT_EQ_OP (src_vregs.num_elements (), 1 );
3179
3184
3185
+ Location loc = dy_gather_op.getLoc ();
3186
+ SmallVector<int32_t > dimensions (dy_gather_op.getDimensions ());
3187
+ if (dy_gather_op.getType ().getElementTypeBitWidth () == 8 ) {
3188
+ if (dy_gather_op.getDimensions () != ArrayRef<int32_t >{0 }) {
3189
+ return dy_gather_op.emitOpError (
3190
+ " Not implemented: 8-bit dynamic gather only supported along "
3191
+ " dimension 0" );
3192
+ }
3193
+ // Vreg shape is 8x128x4, and lowering only supports dimensions == {2, 0},
3194
+ // i.e. byte index is in the upper bits and sublane index in the lower bits.
3195
+ // However, the input indices effectively have sublane index in the upper
3196
+ // bits and byte index in the lower bits.
3197
+ VectorType i32_vreg_ty =
3198
+ getNativeVregType (builder.getI32Type (), ctx.target_shape );
3199
+ VectorType i8_vreg_ty =
3200
+ getNativeVregType (builder.getI8Type (), ctx.target_shape );
3201
+ idx_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
3202
+ const int sublane_bits = llvm::Log2_64 (ctx.target_shape [0 ]);
3203
+ const int byte_bits = 2 ;
3204
+ // This check ensures the shifting below does not change anything across
3205
+ // bytes for relevant (byte and sublane) bits. Lets us mask just once.
3206
+ CHECK_LE (sublane_bits + byte_bits + std::max (byte_bits, sublane_bits), 8 );
3207
+ // Zero out the high bits that specify neither byte nor sublane index (they might
3208
+ // not be zero since op semantics allow wrapping).
3209
+ Value mask = getFullVector (
3210
+ builder, loc, i8_vreg_ty,
3211
+ builder.getI8IntegerAttr ((1 << (byte_bits + sublane_bits)) - 1 ));
3212
+ *v = builder.create <arith::AndIOp>(loc, mask, *v);
3213
+ *v = builder.create <tpu::BitcastVregOp>(loc, i32_vreg_ty, *v);
3214
+ Value shifted_byte = builder.create <arith::ShLIOp>(
3215
+ loc, *v,
3216
+ getFullVector (builder, loc, i32_vreg_ty,
3217
+ builder.getI32IntegerAttr (sublane_bits)));
3218
+ Value shifted_sublane = builder.create <arith::ShRUIOp>(
3219
+ loc, *v,
3220
+ getFullVector (builder, loc, i32_vreg_ty,
3221
+ builder.getI32IntegerAttr (byte_bits)));
3222
+ *v = builder.create <arith::OrIOp>(loc, shifted_byte, shifted_sublane);
3223
+ *v = builder.create <tpu::BitcastVregOp>(loc, i8_vreg_ty, *v);
3224
+ });
3225
+ dimensions = SmallVector<int32_t >{2 , 0 };
3226
+ }
3227
+
3180
3228
xla::Array<Value> out_vregs (src_vregs.dimensions ());
3181
3229
out_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
3182
- *v = builder.create <tpu::DynamicGatherOp>(
3183
- op. getLoc (), src_vregs (idxs). getType (), src_vregs (idxs),
3184
- idx_vregs (idxs), dy_gather_op. getDimension () );
3230
+ *v = builder.create <tpu::DynamicGatherOp>(loc, src_vregs (idxs). getType (),
3231
+ src_vregs (idxs), idx_vregs (idxs),
3232
+ dimensions );
3185
3233
});
3186
3234
3187
3235
dy_gather_op.replaceAllUsesWith (
0 commit comments