
Commit cda896e

root authored and root committed
Addressed the comments:
- Rewrote ConvertAtenOuterOp without unsqueezing the 1-D inputs
- Replaced linalg::MatmulOp with linalg::GenericOp for building the result of the op
- Added error messages for unranked and non-1D inputs
- Added a test case to the e2e tests, placed in matmul.py
1 parent 2348344 commit cda896e

File tree

  • lib/Conversion/TorchToLinalg
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

2 files changed: +90 -127 lines

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 66 additions & 127 deletions
@@ -1674,136 +1674,75 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
 } // namespace
 
 namespace {
-class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    Location loc = op->getLoc();
-    Value lhs = adaptor.getSelf();
-    Value rhs = op->getOperand(1);
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
-      return failure();
-    }
-    auto lhsType = cast<RankedTensorType>(lhs.getType());
-    auto rhsType = cast<RankedTensorType>(rhs.getType());
-
-    auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
-    auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
-
-    // Get the rank of both matrix.
-    unsigned lhsRank = lhsType.getRank();
-    unsigned rhsRank = rhsType.getRank();
-
-    Value lhsZeroPoint, rhsZeroPoint;
-    getZeroPoint(op.getSelf(), lhsZeroPoint);
-    getZeroPoint(op.getOperand(1), rhsZeroPoint);
-
-    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: aten.outer with mixed quantization");
-    }
-
-    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
-    bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
-
-    if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
-      // Allows quantized types to mismatch
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: aten.outer with different input element types");
-    }
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    auto resultType = cast<RankedTensorType>(newResultType);
-    Type elementType = resultType.getElementType();
-
-    // Quantized case
-    if (lhsZeroPoint) {
-      // get each zero point ready to pass to a quantized_matmul
-      lhsZeroPoint = typeConverter->materializeTargetConversion(
-          rewriter, loc,
-          getTypeConverter()->convertType(lhsZeroPoint.getType()),
-          lhsZeroPoint);
-      rhsZeroPoint = typeConverter->materializeTargetConversion(
-          rewriter, loc,
-          getTypeConverter()->convertType(rhsZeroPoint.getType()),
-          rhsZeroPoint);
-      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
-          loc, rewriter.getI32Type(), lhsZeroPoint);
-      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
-          loc, rewriter.getI32Type(), rhsZeroPoint);
-
-      // change uint8 quantization -> int8 quantization
-      int64_t numBits =
-          cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
-      signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
-      numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
-      signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
-
-      if (lhsRank == 1 && rhsRank == 1) {
-        int64_t lhsDim = lhsType.getShape()[0];
-        int64_t rhsDim = rhsType.getShape()[0];
+class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-        // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
-        auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-        auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-        SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-        lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-        rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-        // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
-        Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-        Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-        Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                                ValueRange{lhsDimVal, rhsDimVal},
-                                                elementType);
-
-        // Use the quantized version of matmul.
-        Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
-            loc, zeroTensor.getType(),
-            ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
-            zeroTensor).getResult(0);
-
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-        return success();
-      }
-      return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
-    }
-
-
-    // Non Quantized Outter Product
-    if (lhsRank == 1 && rhsRank == 1) {
-      int64_t lhsDim = lhsType.getShape()[0];
-      int64_t rhsDim = rhsType.getShape()[0];
-
-      // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
-      auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-      auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-      SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-      lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-      rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-      // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-      Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-      Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-      Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                              ValueRange{lhsDimVal, rhsDimVal},
-                                              elementType);
-
-      // Use linalg::MatmulOp to compute the outer product.
-      Value outerProd = rewriter.create<linalg::MatmulOp>(
-          loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
-
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-      return success();
-    }
-
+    Location loc = op->getLoc();
+    Value lhs = adaptor.getSelf();
+    Value rhs = op->getOperand(1);
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
       return failure();
     }
-};
+    auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
+    auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
+
+    if (!lhsType || !rhsType)
+      return rewriter.notifyMatchFailure(op,
+                                         "outer: expected ranked tensor types");
+    if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "outer: expected 1D tensors for outer op lowering");
+
+    Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
+    Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
+    Type elementType = lhsType.getElementType();
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+
+    // Create a zero-initialized tensor with shape [lhsDim, rhsDim].
+    Value zeroTensor = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
+
+    // Set up affine indexing maps:
+    // We create a 2D loop iteration space. For the lhs, we use the first index
+    // (i), for the rhs, the second index (j), and for the result, both (i, j).
+    AffineMap mapLhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(0)},
+                       rewriter.getContext());
+    AffineMap mapRhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)},
+                       rewriter.getContext());
+    AffineMap mapOut =
+        AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+
+    SmallVector<utils::IteratorType, 2> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel};
+
+    Value outerProd =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, zeroTensor.getType(),
+                /*inputs=*/ValueRange{lhs, rhs},
+                /*outputs=*/zeroTensor,
+                /*indexingMaps=*/
+                SmallVector<AffineMap, 3>{mapLhs, mapRhs, mapOut},
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value lhsElem = args[0];
+                  Value rhsElem = args[1];
+                  Value mult = b.create<arith::MulFOp>(loc, lhsElem, rhsElem);
+                  b.create<linalg::YieldOp>(loc, mult);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
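
For reference: the new lowering builds a linalg.generic over a 2-D parallel iteration space whose body computes result[i, j] = lhs[i] * rhs[j] with arith::MulFOp, so it assumes float element types. A minimal Python sketch of that computation, checked against torch.outer (the helper name and vector sizes below are illustrative, not part of the patch):

import torch

def outer_reference(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # Same contract the lowering enforces: both operands must be 1-D.
    assert lhs.dim() == 1 and rhs.dim() == 1
    # Zero-initialized [len(lhs), len(rhs)] accumulator, like createZeroInitTensor.
    out = torch.zeros(lhs.size(0), rhs.size(0), dtype=lhs.dtype)
    # The 2-D parallel loop nest that the linalg.generic expresses:
    # out[i, j] = lhs[i] * rhs[j]
    for i in range(lhs.size(0)):
        for j in range(rhs.size(0)):
            out[i, j] = lhs[i] * rhs[j]
    return out

lhs, rhs = torch.rand(3), torch.rand(4)
assert torch.allclose(outer_reference(lhs, rhs), torch.outer(lhs, rhs))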

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 24 additions & 0 deletions
@@ -918,3 +918,27 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
 def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
+
+
+# ==============================================================================
+
+
+class AtenOuter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuter())
+def AtenOuter_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
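
The new test exercises equal-length static inputs. A hypothetical companion case with distinct vector lengths would also catch a transposed result; a sketch in the same e2e style, assuming the surrounding imports of matmul.py (the class name and sizes are illustrative, not part of this commit):

class AtenOuterDistinctSizes(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.float32, True),
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, lhs, rhs):
        return torch.outer(lhs, rhs)


@register_test_case(module_factory=lambda: AtenOuterDistinctSizes())
def AtenOuterDistinctSizes_basic(module, tu: TestUtils):
    # Distinct lengths: the result must have shape [3, 5], not [5, 3].
    module.forward(tu.rand(3), tu.rand(5))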
