@@ -216,6 +216,41 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
216
216
return indicesTf.getResult ();
217
217
}
218
218
219
+ Value wrapIndicesAroundMax (Value index, int maxIndex, Operation *op,
220
+ PatternRewriter &rewriter) {
221
+ // performs the operation : index = index % maxIndex to wrap index around
222
+ // maxIndex
223
+
224
+ auto maxIndexValue =
225
+ tosa::getConstTensor<int32_t >(rewriter, op, maxIndex, {}).value ();
226
+ auto maxIndexValueMinusOne =
227
+ tosa::getConstTensor<int32_t >(rewriter, op, maxIndex - 1 , {}).value ();
228
+
229
+ auto indexType = dyn_cast<RankedTensorType>(index .getType ());
230
+ auto boolType = indexType.clone (rewriter.getIntegerType (1 ));
231
+
232
+ auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
233
+ rewriter, op->getLoc (), boolType, index , maxIndexValueMinusOne);
234
+ auto wrappedBeyondMaxIndicesQuotient =
235
+ tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc (), indexType,
236
+ index , maxIndexValue)
237
+ .getResult ();
238
+ auto wrappedBeyondMaxIndicesQuotientTimesIndices =
239
+ tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc (), indexType,
240
+ wrappedBeyondMaxIndicesQuotient,
241
+ maxIndexValue, /* shift=*/ 0 )
242
+ .getResult ();
243
+ auto wrappedBeyondMaxIndices =
244
+ tosa::CreateOpAndInfer<tosa::SubOp>(
245
+ rewriter, op->getLoc (), indexType, index ,
246
+ wrappedBeyondMaxIndicesQuotientTimesIndices)
247
+ .getResult ();
248
+
249
+ return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc (),
250
+ indexType, isBeyondMaxIndices,
251
+ wrappedBeyondMaxIndices, index );
252
+ }
253
+
219
254
// Lowers Gather operators to a sequence of TOSA ops.
220
255
// taken from
221
256
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
@@ -403,14 +438,17 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
403
438
flattenedIndicesReduceOp.getResult (),
404
439
rewriter.getDenseI64ArrayAttr (tosaIndicesShape));
405
440
441
+ auto wrappedIndices = wrapIndicesAroundMax (tosaIndicesReshapeOp.getResult (),
442
+ K + 1 , op, rewriter);
443
+
406
444
// Now the gather op itself
407
445
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
408
446
// tensor<1x8x1xf32>
409
447
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
410
448
rewriter, op->getLoc (),
411
449
GetTypeFromTensorShape (tosaGatherResultShape,
412
450
resultType.getElementType ()),
413
- tosaValuesReshapeOp.getResult (), tosaIndicesReshapeOp. getResult () );
451
+ tosaValuesReshapeOp.getResult (), wrappedIndices );
414
452
415
453
// Finally, reshape back to the original output shape of [Indices,
416
454
// ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :
0 commit comments