Commit d167abd

submit the llvm#3902 to local repo (#5)
* Decompose lstm and gru.
* Add tests and update xfail_sets.py
* Rebase main
* Fix casting for arith.cmpi operands to be of same type.
1 parent 972080e commit d167abd

File tree

5 files changed: 227 additions, 32 deletions


lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 19 additions & 3 deletions
@@ -417,6 +417,21 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
 };
 } // namespace
 
+static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
+                                  Value input, int64_t dim) {
+  // performs the operation : index = index % maxIndex to wrap index around
+  // maxIndex
+  Value maxIndexValue = getDimOp(b, loc, input, dim);
+  maxIndexValue =
+      b.createOrFold<arith::IndexCastOp>(loc, index.getType(), maxIndexValue);
+  Value isBeyondMaxIndices = b.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, maxIndexValue);
+  Value wrappedIndices =
+      b.createOrFold<arith::RemSIOp>(loc, index, maxIndexValue);
+  return b.createOrFold<arith::SelectOp>(loc, isBeyondMaxIndices,
+                                         wrappedIndices, index);
+}
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an

@@ -478,16 +493,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 
     auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
                                                      rewriter.getContext());
-
     Value finalRes =
         rewriter
            .create<linalg::GenericOp>(
                 loc, initTensor.getType(), ValueRange{indices}, initTensor,
                 /*indexingMaps=*/indexingMaps,
                 /*iteratorTypes=*/iteratorTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value index = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
+                  Value index =
+                      wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
+                  index = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), index);
                   SmallVector<Value> indexTarget;
                   for (unsigned i = 0; i < inputRank; i++)
                     indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
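
For intuition, the select-over-remainder pattern the new helper emits behaves like the following Python sketch (an eager stand-in with a hypothetical name, not the MLIR builder API; arith.remsi agrees with Python's % here because the select only picks the wrapped value for indices >= max, where both are non-negative):

    def wrap_index_around_max(index: int, max_index: int) -> int:
        # arith.cmpi sge: is the index out of range?
        is_beyond_max = index >= max_index
        # arith.remsi: signed remainder
        wrapped = index % max_index
        # arith.select: wrap only out-of-range indices, pass the rest through
        return wrapped if is_beyond_max else index

    assert wrap_index_around_max(7, 5) == 2   # out-of-range index wraps
    assert wrap_index_around_max(3, 5) == 3   # in-range index is untouched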

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 42 additions & 0 deletions
@@ -4290,6 +4290,42 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
   return success();
 }
 
+Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
+                           ConversionPatternRewriter &rewriter) {
+  // performs the operation : index = index % maxIndex to wrap index around
+  // maxIndex
+
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+  auto maxIndexValueMinusOne =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+
+  auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
+  auto wrappedBeyondMaxIndicesQuotient =
+      tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
+                                             index, maxIndexValue)
+          .getResult();
+  auto wrappedBeyondMaxIndicesQuotientTimesIndices =
+      tosa::createMulOpAndCast(rewriter, op, indexType,
+                               wrappedBeyondMaxIndicesQuotient,
+                               maxIndexValue,
+                               /*shift=*/0)
+          .getResult();
+  auto wrappedBeyondMaxIndices =
+      tosa::CreateOpAndInfer<tosa::SubOp>(
+          rewriter, op->getLoc(), indexType, index,
+          wrappedBeyondMaxIndicesQuotientTimesIndices)
+          .getResult();
+
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isBeyondMaxIndices,
+                                                wrappedBeyondMaxIndices, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
     AtenIndexSelectOp op, OpAdaptor adaptor,
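
TOSA has no remainder op, so the helper above rebuilds index % maxIndex as index - (index / maxIndex) * maxIndex, guarded by a greater-than test against maxIndex - 1. A minimal Python sketch of that decomposition (an eager stand-in, assuming non-negative indices, where tosa.intdiv's truncation and Python's floor division agree):

    def wrap_indices_around_max(index: int, max_index: int) -> int:
        # tosa.greater against max_index - 1, i.e. index > max_index - 1
        is_beyond_max = index > max_index - 1
        # tosa.intdiv, tosa.mul, tosa.sub reconstruct the remainder
        quotient = index // max_index
        wrapped = index - quotient * max_index
        # tosa.select keeps in-range indices untouched
        return wrapped if is_beyond_max else index

    assert wrap_indices_around_max(360, 360) == 0    # wraps to the buffer start
    assert wrap_indices_around_max(359, 360) == 359  # already in range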
@@ -4334,6 +4370,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
           .value();
   }
 
+  int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
+                                         1, std::multiplies<int64_t>());
+  index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);
+
   // Get positive dim
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
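
Note the wrap bound here is the input's total element count, not the size of one dimension. In Python terms (math.prod standing in for std::accumulate with std::multiplies<int64_t>; the shape is a hypothetical example):

    import math

    input_shape = [6, 60]                    # hypothetical static input shape
    self_num_elems = math.prod(input_shape)  # 360: the bound passed to the wrapper
    # Indices >= 360 would be wrapped by wrap_indices_around_max (sketched above).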
@@ -7705,10 +7745,12 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
     // coord_i_n * stride[n]
     int32_t index = offset;
     int64_t coordFinder = i;
+
     for (int64_t dim = 0; dim < outputRank; dim++) {
       int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
       index += indexCoord * stride[outputRank - dim - 1];
       coordFinder /= outputSize[outputRank - dim - 1];
+      index = (index % selfNumElems);
     }
     targetIndicesVec.push_back(index);
   }
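
A hedged Python rendering of the updated index math (an eager stand-in for the loop the lowering unrolls; reducing modulo selfNumElems inside the loop is equivalent to one final reduction, since the index only accumulates additions):

    import math

    def as_strided_target_indices(output_size, stride, offset, self_num_elems):
        # Map each flattened output position i to a wrapped storage index.
        indices = []
        for i in range(math.prod(output_size)):
            index, coord_finder = offset, i
            for dim in range(len(output_size)):
                index_coord = coord_finder % output_size[-dim - 1]
                index += index_coord * stride[-dim - 1]
                coord_finder //= output_size[-dim - 1]
                index %= self_num_elems
            indices.append(index)
        return indices

    # With the new test's parameters, the offset of 360 wraps to the buffer start:
    assert as_strided_target_indices([6, 20], [60, 1], 360, 360)[0] == 0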

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -489,6 +489,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dStaticEvenMultiple_basic",
     "AdaptiveAvgPool1dStaticLargerOutput_basic",

@@ -906,6 +907,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "Unfold_Module_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_basic",

@@ -1803,6 +1805,7 @@
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
+    "AsStridedWithOffsetModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseCosIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py

Lines changed: 29 additions & 0 deletions
@@ -1144,3 +1144,32 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
 def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5))
+
+
+# ==============================================================================
+
+
+class AsStridedWithOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 6, 60], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        output_size = [6, 20]
+        stride = [60, 1]
+        slice = torch.ops.aten.slice.Tensor(x, 0, 1, 2)
+        squeeze = torch.ops.aten.squeeze.dim(slice, 0)
+        return torch.ops.aten.as_strided(
+            squeeze, size=output_size, stride=stride, storage_offset=360
+        )
+
+
+@register_test_case(module_factory=lambda: AsStridedWithOffsetModule())
+def AsStridedWithOffsetModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(2, 6, 60))
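
A hedged reading of why this test exercises the new wrapping: the squeezed slice holds 6 * 60 = 360 elements, while storage_offset=360 points one full slice past the start of the original storage, which is exactly where x[1] begins. In eager PyTorch the module is equivalent to:

    import torch

    x = torch.rand(2, 6, 60)
    # as_strided's storage_offset is absolute in the underlying storage, so
    # offset 360 on x's storage lands at the first element of x[1].
    view = torch.ops.aten.as_strided(
        x[1], size=[6, 20], stride=[60, 1], storage_offset=360
    )
    assert torch.equal(view, x[1, :, :20])

Under the TOSA lowering, the squeezed operand is materialized as its own 360-element buffer, so the wrapping above maps offset 360 back to element 0 of that buffer and reads the same values.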
