Skip to content

Commit 8fcdfbe

Browse files
committed
Wrap indices around max value for tosa.gather creation.
1 parent 714b7fc commit 8fcdfbe

File tree

3 files changed

+142
-100
lines changed

3 files changed

+142
-100
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4006,41 +4006,6 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
40064006
return success();
40074007
}
40084008

4009-
Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
4010-
ConversionPatternRewriter &rewriter) {
4011-
// performs the operation : index = index % maxIndex to wrap index around
4012-
// maxIndex
4013-
4014-
auto maxIndexValue =
4015-
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
4016-
auto maxIndexValueMinusOne =
4017-
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();
4018-
4019-
auto indexType = dyn_cast<RankedTensorType>(index.getType());
4020-
auto boolType = indexType.clone(rewriter.getIntegerType(1));
4021-
4022-
auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
4023-
rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
4024-
auto wrappedBeyondMaxIndicesQuotient =
4025-
tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
4026-
index, maxIndexValue)
4027-
.getResult();
4028-
auto wrappedBeyondMaxIndicesQuotientTimesIndices =
4029-
tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
4030-
wrappedBeyondMaxIndicesQuotient,
4031-
maxIndexValue, /*shift=*/0)
4032-
.getResult();
4033-
auto wrappedBeyondMaxIndices =
4034-
tosa::CreateOpAndInfer<tosa::SubOp>(
4035-
rewriter, op->getLoc(), indexType, index,
4036-
wrappedBeyondMaxIndicesQuotientTimesIndices)
4037-
.getResult();
4038-
4039-
return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
4040-
indexType, isBeyondMaxIndices,
4041-
wrappedBeyondMaxIndices, index);
4042-
}
4043-
40444009
template <>
40454010
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
40464011
AtenIndexSelectOp op, OpAdaptor adaptor,
@@ -4084,10 +4049,6 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
40844049
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
40854050
}
40864051

4087-
int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
4088-
1, std::multiplies<int64_t>());
4089-
index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);
4090-
40914052
// Get positive dim
40924053
int64_t dim;
40934054
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,41 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
216216
return indicesTf.getResult();
217217
}
218218

219+
Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
220+
PatternRewriter &rewriter) {
221+
// performs the operation : index = index % maxIndex to wrap index around
222+
// maxIndex
223+
224+
auto maxIndexValue =
225+
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
226+
auto maxIndexValueMinusOne =
227+
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();
228+
229+
auto indexType = dyn_cast<RankedTensorType>(index.getType());
230+
auto boolType = indexType.clone(rewriter.getIntegerType(1));
231+
232+
auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
233+
rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
234+
auto wrappedBeyondMaxIndicesQuotient =
235+
tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
236+
index, maxIndexValue)
237+
.getResult();
238+
auto wrappedBeyondMaxIndicesQuotientTimesIndices =
239+
tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
240+
wrappedBeyondMaxIndicesQuotient,
241+
maxIndexValue, /*shift=*/0)
242+
.getResult();
243+
auto wrappedBeyondMaxIndices =
244+
tosa::CreateOpAndInfer<tosa::SubOp>(
245+
rewriter, op->getLoc(), indexType, index,
246+
wrappedBeyondMaxIndicesQuotientTimesIndices)
247+
.getResult();
248+
249+
return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
250+
indexType, isBeyondMaxIndices,
251+
wrappedBeyondMaxIndices, index);
252+
}
253+
219254
// Lowers Gather operators to a sequence of TOSA ops.
220255
// taken from
221256
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
@@ -403,14 +438,17 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
403438
flattenedIndicesReduceOp.getResult(),
404439
rewriter.getDenseI64ArrayAttr(tosaIndicesShape));
405440

441+
auto wrappedIndices = wrapIndicesAroundMax(tosaIndicesReshapeOp.getResult(),
442+
K + 1, op, rewriter);
443+
406444
// Now the gather op itself
407445
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
408446
// tensor<1x8x1xf32>
409447
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
410448
rewriter, op->getLoc(),
411449
GetTypeFromTensorShape(tosaGatherResultShape,
412450
resultType.getElementType()),
413-
tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult());
451+
tosaValuesReshapeOp.getResult(), wrappedIndices);
414452

415453
// Finally, reshape back to the original output shape of [Indices,
416454
// ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :

0 commit comments

Comments
 (0)