1+ // RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s
2+ // Generally, the test cases accumulated here come from running the importer
3+ // over all included backend tests that involve simple ops with no model
4+ // level constants. This is a pragmatic choice which lets us have a lot
5+ // of tests in this file, whereas the others tend to be more bespoke.
6+
7+ // CHECK-LABEL: @test_matmul_2d
8+ func.func @test_matmul_2d (%arg0: !torch.vtensor <[3 ,4 ],f32 >, %arg1: !torch.vtensor <[4 ,3 ],f32 >) -> !torch.vtensor <[3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
9+ // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
10+ %0 = torch.operator " onnx.MatMul" (%arg0 , %arg1 ) : (!torch.vtensor <[3 ,4 ],f32 >, !torch.vtensor <[4 ,3 ],f32 >) -> !torch.vtensor <[3 ,3 ],f32 >
11+ return %0 : !torch.vtensor <[3 ,3 ],f32 >
12+ }
13+
14+ // CHECK-LABEL: @test_matmul_3d
15+ func.func @test_matmul_3d (%arg0: !torch.vtensor <[2 ,3 ,4 ],f32 >, %arg1: !torch.vtensor <[2 ,4 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
16+ // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32>
17+ %0 = torch.operator " onnx.MatMul" (%arg0 , %arg1 ) : (!torch.vtensor <[2 ,3 ,4 ],f32 >, !torch.vtensor <[2 ,4 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ,3 ],f32 >
18+ return %0 : !torch.vtensor <[2 ,3 ,3 ],f32 >
19+ }
20+
21+ // CHECK-LABEL: @test_matmul_4d
22+ func.func @test_matmul_4d (%arg0: !torch.vtensor <[1 ,2 ,3 ,4 ],f32 >, %arg1: !torch.vtensor <[1 ,2 ,4 ,3 ],f32 >) -> !torch.vtensor <[1 ,2 ,3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
23+ // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32>
24+ %0 = torch.operator " onnx.MatMul" (%arg0 , %arg1 ) : (!torch.vtensor <[1 ,2 ,3 ,4 ],f32 >, !torch.vtensor <[1 ,2 ,4 ,3 ],f32 >) -> !torch.vtensor <[1 ,2 ,3 ,3 ],f32 >
25+ return %0 : !torch.vtensor <[1 ,2 ,3 ,3 ],f32 >
26+ }
0 commit comments