diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 335277aa4462..2217d99e1a28 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1894,6 +1894,58 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
 };
 } // namespace
 
+// Decompose 'aten.outer' into 'aten.unsqueeze', 'aten.matmul'
+
+namespace {
+class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOuterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Value vec2 = op.getVec2();
+    Type opType = op.getType();
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    auto vec2Type = cast<BaseTensorType>(vec2.getType());
+
+    // Check if both tensors are 1-dimensional.
+    SmallVector<int64_t> inputShape(inputType.getSizes());
+    SmallVector<int64_t> vec2Shape(vec2Type.getSizes());
+
+    if (inputShape.size() == 1 && vec2Shape.size() == 1) {
+      // Unsqueeze the LHS vector [n] into an [n, 1] column matrix.
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1)); // Dimension index
+      SmallVector<int64_t> inputMatrixShape = {inputShape[0], 1};
+      Type inputMatrixType = inputType.getWithSizesAndDtype(
+          inputMatrixShape, inputType.getOptionalDtype());
+
+      Value inputMatrix =
+          rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
+
+      // Unsqueeze the RHS vector [m] into a [1, m] row matrix.
+      Value zero = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(0));
+      SmallVector<int64_t> vec2MatrixShape = {1, vec2Shape[0]};
+      Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
+          vec2MatrixShape, vec2Type.getOptionalDtype());
+
+      Value vec2Matrix =
+          rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
+
+      // [n, 1] matmul [1, m] yields the [n, m] outer product.
+      rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
+                                                vec2Matrix);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose aten.atleast_2d into: aten.reshape. See
 // https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
@@ -11591,6 +11643,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenOuterOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6695f2964b65..95bacd9fc9e6 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -379,6 +379,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenOuterOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 6904b4acb3c6..00a493c0db58 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3810,6 +3810,7 @@
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "AtenOuter_basic",
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "AtenNonzero1DDynamicModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 17240cf953df..79bacc2bb9c7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -918,3 +918,61 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
 def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
+
+
+# ==============================================================================
+
+
+class AtenOuter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.float32, True),
+            ([3], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuter())
+def AtenOuter_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
+
+
+# ==============================================================================
+
+
+class AtenOuterDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.rand(5))
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_lhs_larger(module, tu: TestUtils):
+    module.forward(tu.rand(7), tu.rand(4))
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_rhs_larger(module, tu: TestUtils):
+    module.forward(tu.rand(2), tu.rand(6))
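
Note on the rewrite: for rank-1 operands, aten.outer is exactly a matmul of the unsqueezed vectors, which is the identity the DecomposeAtenOuterOp pattern relies on and the e2e cases above exercise through torch-mlir. A minimal standalone sketch in plain PyTorch, for illustration only and not part of the patch:

    import torch

    a = torch.rand(7)  # LHS vector of length n
    b = torch.rand(4)  # RHS vector of length m

    # The decomposition: [n] -> [n, 1] column matrix, [m] -> [1, m] row
    # matrix, then matmul produces the [n, m] outer product.
    decomposed = torch.matmul(a.unsqueeze(1), b.unsqueeze(0))

    assert torch.allclose(torch.outer(a, b), decomposed)
    print(decomposed.shape)  # torch.Size([7, 4])

Each output element is a single multiply in both paths (a dot product of length 1), so the decomposed form matches eager torch.outer for any rank-1 shapes, including the dynamic-shape test cases.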