@@ -1674,136 +1674,75 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
1674
1674
} // namespace
1675
1675
1676
1676
namespace {
1677
- class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1678
- public:
1679
- using OpConversionPattern::OpConversionPattern;
1680
- LogicalResult
1681
- matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1682
- ConversionPatternRewriter &rewriter) const override {
1683
-
1684
- Location loc = op->getLoc ();
1685
- Value lhs = adaptor.getSelf ();
1686
- Value rhs = op->getOperand (1 );
1687
-
1688
- if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1689
- return failure ();
1690
- }
1691
- auto lhsType = cast<RankedTensorType>(lhs.getType ());
1692
- auto rhsType = cast<RankedTensorType>(rhs.getType ());
1693
-
1694
- auto lhsTorchType = cast<ValueTensorType>(op.getSelf ().getType ());
1695
- auto rhsTorchType = cast<ValueTensorType>(op.getOperand (1 ).getType ());
1696
-
1697
- // Get the rank of both matrix.
1698
- unsigned lhsRank = lhsType.getRank ();
1699
- unsigned rhsRank = rhsType.getRank ();
1700
-
1701
- Value lhsZeroPoint, rhsZeroPoint;
1702
- getZeroPoint (op.getSelf (), lhsZeroPoint);
1703
- getZeroPoint (op.getOperand (1 ), rhsZeroPoint);
1704
-
1705
- if (static_cast <bool >(lhsZeroPoint) != static_cast <bool >(rhsZeroPoint)) {
1706
- return rewriter.notifyMatchFailure (
1707
- op, " unsupported: aten.outer with mixed quantization" );
1708
- }
1709
-
1710
- bool isUnsigned = torch_to_linalg::isUnsignedTorchType (lhsTorchType);
1711
- bool isUnsignedR = torch_to_linalg::isUnsignedTorchType (rhsTorchType);
1712
-
1713
- if (!lhsZeroPoint && lhsTorchType.getDtype () != rhsTorchType.getDtype ()) {
1714
- // Allows quantized types to mismatch
1715
- return rewriter.notifyMatchFailure (
1716
- op, " unsupported: aten.outer with different input element types" );
1717
- }
1718
-
1719
- Type newResultType = getTypeConverter ()->convertType (op.getType ());
1720
- auto resultType = cast<RankedTensorType>(newResultType);
1721
- Type elementType = resultType.getElementType ();
1722
-
1723
- // Quantized case
1724
- if (lhsZeroPoint) {
1725
- // get each zero point ready to pass to a quantized_matmul
1726
- lhsZeroPoint = typeConverter->materializeTargetConversion (
1727
- rewriter, loc,
1728
- getTypeConverter ()->convertType (lhsZeroPoint.getType ()),
1729
- lhsZeroPoint);
1730
- rhsZeroPoint = typeConverter->materializeTargetConversion (
1731
- rewriter, loc,
1732
- getTypeConverter ()->convertType (rhsZeroPoint.getType ()),
1733
- rhsZeroPoint);
1734
- lhsZeroPoint = rewriter.create <arith::TruncIOp>(
1735
- loc, rewriter.getI32Type (), lhsZeroPoint);
1736
- rhsZeroPoint = rewriter.create <arith::TruncIOp>(
1737
- loc, rewriter.getI32Type (), rhsZeroPoint);
1738
-
1739
- // change uint8 quantization -> int8 quantization
1740
- int64_t numBits =
1741
- cast<mlir::IntegerType>(lhsType.getElementType ()).getWidth ();
1742
- signShift (rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
1743
- numBits = cast<mlir::IntegerType>(rhsType.getElementType ()).getWidth ();
1744
- signShift (rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
1745
-
1746
- if (lhsRank == 1 && rhsRank == 1 ) {
1747
- int64_t lhsDim = lhsType.getShape ()[0 ];
1748
- int64_t rhsDim = rhsType.getShape ()[0 ];
1677
+ class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1678
+ public:
1679
+ using OpConversionPattern::OpConversionPattern;
1680
+ LogicalResult
1681
+ matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1682
+ ConversionPatternRewriter &rewriter) const override {
1749
1683
1750
- // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
1751
- auto lhsUnsqueezeType = RankedTensorType::get ({lhsDim, 1 }, lhsType.getElementType ());
1752
- auto rhsUnsqueezeType = RankedTensorType::get ({1 , rhsDim}, rhsType.getElementType ());
1753
- SmallVector<ReassociationIndices> reassociation = {{0 , 1 }};
1754
- lhs = rewriter.create <tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1755
- rhs = rewriter.create <tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1756
-
1757
- // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
1758
- Value lhsDimVal = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
1759
- Value rhsDimVal = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
1760
- Value zeroTensor = createZeroInitTensor (rewriter, loc,
1761
- ValueRange{lhsDimVal, rhsDimVal},
1762
- elementType);
1763
-
1764
- // Use the quantized version of matmul.
1765
- Value outerProd = rewriter.create <linalg::QuantizedMatmulOp>(
1766
- loc, zeroTensor.getType (),
1767
- ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
1768
- zeroTensor).getResult (0 );
1769
-
1770
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1771
- return success ();
1772
- }
1773
- return rewriter.notifyMatchFailure (op, " unsupported: quantized aten.outer op case" );
1774
- }
1775
-
1776
-
1777
- // Non Quantized Outter Product
1778
- if (lhsRank == 1 && rhsRank == 1 ) {
1779
- int64_t lhsDim = lhsType.getShape ()[0 ];
1780
- int64_t rhsDim = rhsType.getShape ()[0 ];
1781
-
1782
- // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
1783
- auto lhsUnsqueezeType = RankedTensorType::get ({lhsDim, 1 }, lhsType.getElementType ());
1784
- auto rhsUnsqueezeType = RankedTensorType::get ({1 , rhsDim}, rhsType.getElementType ());
1785
- SmallVector<ReassociationIndices> reassociation = {{0 , 1 }};
1786
- lhs = rewriter.create <tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1787
- rhs = rewriter.create <tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1788
-
1789
- // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1790
- Value lhsDimVal = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
1791
- Value rhsDimVal = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
1792
- Value zeroTensor = createZeroInitTensor (rewriter, loc,
1793
- ValueRange{lhsDimVal, rhsDimVal},
1794
- elementType);
1795
-
1796
- // Use linalg::MatmulOp to compute the outer product.
1797
- Value outerProd = rewriter.create <linalg::MatmulOp>(
1798
- loc, zeroTensor.getType (), ValueRange{lhs, rhs}, zeroTensor).getResult (0 );
1799
-
1800
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1801
- return success ();
1802
- }
1803
-
1684
+ Location loc = op->getLoc ();
1685
+ Value lhs = adaptor.getSelf ();
1686
+ Value rhs = op->getOperand (1 );
1687
+
1688
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1804
1689
return failure ();
1805
1690
}
1806
- };
1691
+ auto lhsType = cast<RankedTensorType>(lhs.getType ());
1692
+ auto rhsType = cast<RankedTensorType>(rhs.getType ());
1693
+
1694
+ if (!lhsType || !rhsType)
1695
+ return rewriter.notifyMatchFailure (op,
1696
+ " outer: expected ranked tensor types" );
1697
+ if (lhsType.getRank () != 1 || rhsType.getRank () != 1 )
1698
+ return rewriter.notifyMatchFailure (
1699
+ op, " outer: expected 1D tensors for outer op lowering" );
1700
+
1701
+ Value lhsDim = getDimOp (rewriter, loc, lhs, 1 );
1702
+ Value rhsDim = getDimOp (rewriter, loc, rhs, 1 );
1703
+ Type elementType = lhsType.getElementType ();
1704
+ Type newResultType = getTypeConverter ()->convertType (op.getType ());
1705
+
1706
+ // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1707
+ Value zeroTensor = createZeroInitTensor (
1708
+ rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
1709
+
1710
+ // Set up affine indexing maps:
1711
+ // We create a 2D loop iteration space. For the lhs, we use the first index
1712
+ // (i), for the rhs, the second index (j), and for the result, both (i, j).
1713
+ AffineMap mapLhs =
1714
+ AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (0 )},
1715
+ rewriter.getContext ());
1716
+ AffineMap mapRhs =
1717
+ AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (1 )},
1718
+ rewriter.getContext ());
1719
+ AffineMap mapOut =
1720
+ AffineMap::getMultiDimIdentityMap (2 , rewriter.getContext ());
1721
+
1722
+ SmallVector<utils::IteratorType, 2 > iteratorTypes = {
1723
+ utils::IteratorType::parallel, utils::IteratorType::parallel};
1724
+
1725
+ Value outerProd =
1726
+ rewriter
1727
+ .create <linalg::GenericOp>(
1728
+ loc, zeroTensor.getType (),
1729
+ /* inputs=*/ ValueRange{lhsDim, rhsDim},
1730
+ /* outputs=*/ zeroTensor,
1731
+ /* indexingMaps=*/
1732
+ SmallVector<AffineMap, 3 >{mapLhs, mapRhs, mapOut},
1733
+ /* iteratortType=*/ iteratorTypes,
1734
+ [&](OpBuilder &b, Location loc, ValueRange args) {
1735
+ Value lhsElem = args[0 ];
1736
+ Value rhsElem = args[1 ];
1737
+ Value mult = b.create <arith::MulFOp>(loc, lhsElem, rhsElem);
1738
+ b.create <linalg::YieldOp>(loc, mult);
1739
+ })
1740
+ .getResult (0 );
1741
+
1742
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1743
+ return success ();
1744
+ }
1745
+ };
1807
1746
} // namespace
1808
1747
1809
1748
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality (
0 commit comments