Skip to content

Commit f2c205f

Browse files
committed
Submit llvm#3902 to local repo (llvm#5)
* Decompose lstm and gru.
* Add tests and update xfail_sets.py.
* Rebase onto main.
* Fix casting so arith.cmpi operands have the same type.
1 parent 0176481 commit f2c205f

File tree

5 files changed

+227
-32
lines changed

5 files changed

+227
-32
lines changed

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,21 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
417417
};
418418
} // namespace
419419

420+
/// Wraps `index` into the valid range of dimension `dim` of `input`:
/// when index >= size(input, dim) the result is `index % size`, otherwise
/// the index is returned unchanged.
/// NOTE(review): negative indices pass through unmodified — presumably the
/// caller guarantees non-negative indices; confirm before relying on this
/// for signed wrap-around.
static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
                                  Value input, int64_t dim) {
  // Dimension extent, cast to the integer type of `index` so the arith ops
  // below see matching operand types.
  Value dimSize = getDimOp(b, loc, input, dim);
  dimSize = b.createOrFold<arith::IndexCastOp>(loc, index.getType(), dimSize);
  // Candidate wrapped value: index % dimSize (signed remainder).
  Value remainder = b.createOrFold<arith::RemSIOp>(loc, index, dimSize);
  // Only indices at or beyond the extent get wrapped.
  Value exceedsDim = b.createOrFold<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sge, index, dimSize);
  return b.createOrFold<arith::SelectOp>(loc, exceedsDim, remainder, index);
}
434+
420435
namespace {
421436
// Let's say we have an input tensor: initialized with some random values of
422437
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -478,16 +493,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
478493

479494
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
480495
rewriter.getContext());
481-
482496
Value finalRes =
483497
rewriter
484498
.create<linalg::GenericOp>(
485499
loc, initTensor.getType(), ValueRange{indices}, initTensor,
486500
/*indexingMaps=*/indexingMaps,
487501
/*iteratorTypes=*/iteratorTypes,
488502
[&](OpBuilder &b, Location loc, ValueRange args) {
489-
Value index = rewriter.create<arith::IndexCastOp>(
490-
loc, rewriter.getIndexType(), args[0]);
503+
Value index =
504+
wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
505+
index = rewriter.create<arith::IndexCastOp>(
506+
loc, rewriter.getIndexType(), index);
491507
SmallVector<Value> indexTarget;
492508
for (unsigned i = 0; i < inputRank; i++)
493509
indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4219,6 +4219,42 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
42194219
return success();
42204220
}
42214221

4222+
/// Wraps `index` into [0, maxIndex) for the TOSA lowering path.
/// Computes index - (index / maxIndex) * maxIndex (i.e. index % maxIndex,
/// since TOSA has no remainder op) and selects the wrapped value only when
/// index > maxIndex - 1; in-range indices are returned unchanged.
/// NOTE(review): negative indices pass through unmodified — assumed already
/// normalized by the caller; confirm against the callers of this helper.
Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
                           ConversionPatternRewriter &rewriter) {
  auto maxIndexValue =
      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
  auto maxIndexValueMinusOne =
      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();

  auto indexType = dyn_cast<RankedTensorType>(index.getType());
  auto boolType = indexType.clone(rewriter.getIntegerType(1));

  // index > maxIndex - 1, i.e. index >= maxIndex.
  auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
      rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
  auto wrappedBeyondMaxIndicesQuotient =
      tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
                                             index, maxIndexValue)
          .getResult();
  // BUGFIX: multiply the quotient by maxIndexValue (the original multiplied
  // the quotient by ITSELF), so that
  //   index - quotient * maxIndex == index % maxIndex.
  auto wrappedBeyondMaxIndicesQuotientTimesMax =
      tosa::createMulOpAndCast(rewriter, op, indexType,
                               wrappedBeyondMaxIndicesQuotient, maxIndexValue,
                               /*shift=*/0)
          .getResult();
  auto wrappedBeyondMaxIndices =
      tosa::CreateOpAndInfer<tosa::SubOp>(
          rewriter, op->getLoc(), indexType, index,
          wrappedBeyondMaxIndicesQuotientTimesMax)
          .getResult();

  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
                                                indexType, isBeyondMaxIndices,
                                                wrappedBeyondMaxIndices, index);
}
4257+
42224258
template <>
42234259
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
42244260
AtenIndexSelectOp op, OpAdaptor adaptor,
@@ -4263,6 +4299,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
42634299
.value();
42644300
}
42654301

4302+
int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
4303+
1, std::multiplies<int64_t>());
4304+
index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);
4305+
42664306
// Get positive dim
42674307
int64_t dim;
42684308
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -7689,10 +7729,12 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
76897729
// coord_i_n * stride[n]
76907730
int32_t index = offset;
76917731
int64_t coordFinder = i;
7732+
76927733
for (int64_t dim = 0; dim < outputRank; dim++) {
76937734
int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
76947735
index += indexCoord * stride[outputRank - dim - 1];
76957736
coordFinder /= outputSize[outputRank - dim - 1];
7737+
index = (index % selfNumElems);
76967738
}
76977739
targetIndicesVec.push_back(index);
76987740
}

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@
489489
"SplitTensorNegativeDimModule_basic",
490490
"SplitWithSizesListUnpackModule_basic",
491491
"SplitWithSizes_Module_basic",
492+
"AsStridedWithOffsetModule_basic",
492493
"AdaptiveAvgPool1dGeneralDynamic_basic",
493494
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
494495
"AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -914,6 +915,7 @@
914915
"SplitTensorNegativeDimModule_basic",
915916
"SplitWithSizesListUnpackModule_basic",
916917
"SplitWithSizes_Module_basic",
918+
"AsStridedWithOffsetModule_basic",
917919
"Unfold_Module_basic",
918920
"Unfold_Module_Rank_4",
919921
"Unfold_Module_Rank_Zero_basic",
@@ -1824,6 +1826,7 @@
18241826
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
18251827
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
18261828
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
1829+
"AsStridedWithOffsetModule_basic",
18271830
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
18281831
"ElementwiseCosIntModule_basic",
18291832
"ElementwiseReciprocalIntModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,3 +1144,32 @@ def forward(self, x):
11441144
@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
11451145
def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
11461146
module.forward(tu.rand(2, 5))
1147+
1148+
1149+
# ==============================================================================
1150+
1151+
1152+
class AsStridedWithOffsetModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 6, 60], torch.float32, True),
        ]
    )
    def forward(self, x):
        # Select the second batch entry and drop the batch dimension, then
        # view the result through as_strided with a nonzero storage offset
        # (360 = one full [6, 60] slice) to exercise offset handling.
        sliced = torch.ops.aten.slice.Tensor(x, 0, 1, 2)
        squeezed = torch.ops.aten.squeeze.dim(sliced, 0)
        return torch.ops.aten.as_strided(
            squeezed, size=[6, 20], stride=[60, 1], storage_offset=360
        )
1171+
1172+
1173+
@register_test_case(module_factory=lambda: AsStridedWithOffsetModule())
def AsStridedWithOffsetModule_basic(module, tu: TestUtils):
    # Use tu.rand (not torch.rand) so the input comes from the suite's seeded
    # generator — keeps the e2e comparison deterministic and matches every
    # other test case in this file.
    module.forward(tu.rand(2, 6, 60))

0 commit comments

Comments
 (0)