@@ -191,8 +191,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
191
191
Value lhs = adaptor.getSelf ();
192
192
Value rhs = adaptor.getOther ();
193
193
194
- if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
194
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
195
195
return failure ();
196
+ }
196
197
auto lhsType = lhs.getType ().cast <RankedTensorType>();
197
198
auto rhsType = rhs.getType ().cast <RankedTensorType>();
198
199
@@ -260,7 +261,26 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
260
261
return success ();
261
262
}
262
263
263
- // Fourth Case: Batch-Matrix Multiplication.
264
+ // Fourth Case: Vec-Vec Multiplication.
265
+ if (lhsRank == 2 && rhsRank == 2 ) {
266
+ Value lhsDim0 = getDimOp (rewriter, loc, lhs, 0 );
267
+ Value lhsDim1 = getDimOp (rewriter, loc, lhs, 1 );
268
+ Value rhsDim0 = getDimOp (rewriter, loc, rhs, 0 );
269
+ Value rhsDim1 = getDimOp (rewriter, loc, rhs, 1 );
270
+ checkDimEqualHelper (rewriter, loc, lhsDim1, rhsDim0);
271
+
272
+ Value zeroTensor = createZeroInitTensor (
273
+ rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
274
+ Value matmul =
275
+ rewriter
276
+ .create <linalg::MatmulOp>(loc, zeroTensor.getType (),
277
+ ValueRange{lhs, rhs}, zeroTensor)
278
+ .getResult (0 );
279
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, matmul);
280
+ return success ();
281
+ }
282
+
283
+ // Fifth Case: Batch-Matrix Multiplication.
264
284
// TODO: Handle batch matrix multiplication when one of the matrix is unity
265
285
// rank and the other has batch dimension.
266
286
if (lhsRank > 1 && rhsRank > 1 ) {
0 commit comments