Skip to content

Commit a24aadb

Browse files
authored
[aten] Make torch.aten.matmul to linalg work for non-broadcasting case (#2659)
Broadcasting for `torch.aten.matmul` is optional so a MxN with NxK matmul should be legalized to a `linalg.matmul`.
1 parent 8fa81d1 commit a24aadb

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
191191
Value lhs = adaptor.getSelf();
192192
Value rhs = adaptor.getOther();
193193

194-
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
194+
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
195195
return failure();
196+
}
196197
auto lhsType = lhs.getType().cast<RankedTensorType>();
197198
auto rhsType = rhs.getType().cast<RankedTensorType>();
198199

@@ -260,7 +261,26 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
260261
return success();
261262
}
262263

263-
// Fourth Case: Batch-Matrix Multiplication.
264+
// Fourth Case: Vec-Vec Multiplication.
265+
if (lhsRank == 2 && rhsRank == 2) {
266+
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
267+
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
268+
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
269+
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
270+
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
271+
272+
Value zeroTensor = createZeroInitTensor(
273+
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
274+
Value matmul =
275+
rewriter
276+
.create<linalg::MatmulOp>(loc, zeroTensor.getType(),
277+
ValueRange{lhs, rhs}, zeroTensor)
278+
.getResult(0);
279+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
280+
return success();
281+
}
282+
283+
// Fifth Case: Batch-Matrix Multiplication.
264284
// TODO: Handle batch matrix multiplication when one of the matrix is unity
265285
// rank and the other has batch dimension.
266286
if (lhsRank > 1 && rhsRank > 1) {

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
2929

3030
// -----
3131

32+
// CHECK-LABEL: func.func @torch.aten.matmul.2d
33+
func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
34+
// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32>
35+
// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
36+
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
37+
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32>
38+
// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32>
39+
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32>
40+
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
41+
return %0 : !torch.vtensor<[8,8],f32>
42+
}
43+
44+
// -----
45+
3246
// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
3347
// CHECK-NOT: assert
3448
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>

0 commit comments

Comments
 (0)