@@ -76,3 +76,77 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.
76
76
%2 = torch.aten.convolution %arg0 , %arg1 , %arg2 , %0 , %0 , %0 , %false , %1 , %int1 : !torch.vtensor <[1 ,80 ,3000 ],f32 >, !torch.vtensor <[1024 ,80 ,3 ],f32 >, !torch.vtensor <[1024 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[1 ,1024 ,3000 ],f32 >
77
77
return %2 : !torch.vtensor <[1 ,1024 ,3000 ],f32 >
78
78
}
79
// Transposed (deconvolution) case: weights are in (in, out, kH, kW) layout and
// must be flipped/transposed into FCHW by a linalg.generic before the lowering
// can reuse the ordinary linalg.conv_2d_nchw_fchw op on a zero-padded input.
// CHECK-LABEL: func.func @transposedConv2D(
// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
// CHECK: = linalg.generic
// CHECK-SAME: outs(%[[BROADCASTED_WEIGHTS_INIT:.*]] : tensor<4x2x3x3xf32>) {
// CHECK: %[[WEIGHTS:.*]] = tensor.extract
// CHECK-SAME: : tensor<2x4x3x3xf32>
// CHECK-NEXT: linalg.yield %[[BROADCASTED_WEIGHTS:.*]] : f32
// CHECK-NEXT: } -> tensor<4x2x3x3xf32>
// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
// CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[CONV_RESULT:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %true = torch.constant.bool true
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_4_3_3_torch.float32> : tensor<2x4x3x3xf32>) : !torch.vtensor<[2,4,3,3],f32>
  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  // stride = [2, 2], padding = [0, 0], dilation = [1, 1], output_padding = [0, 0]
  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  // transposed = true, groups = 1
  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
  return %6 : !torch.vtensor<[1,4,10,14],f32>
}
107
// Grouped (non-transposed) case: groups = 2 lowers to linalg.conv_2d_ngchw_gfchw
// on an input expanded to NGCHW, with the result collapsed back to NCHW.
// CHECK-LABEL: func.func @groupedConvolution2D(
// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.collapse_shape
// CHECK-SAME: tensor<1x2x2x5x7xf32> into tensor<1x4x5x7xf32>
func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_4_2_3_3_torch.float32> : tensor<4x2x3x3xf32>) : !torch.vtensor<[4,2,3,3],f32>
  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  // stride = [1, 1], padding = [1, 1], dilation = [1, 1], output_padding = [0, 0]
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  // transposed = false, groups = 2
  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %false, %5, %int2 : !torch.vtensor<[1,4,5,7],f32>, !torch.vtensor<[4,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,5,7],f32>
  return %6 : !torch.vtensor<[1,4,5,7],f32>
}
129
// Combined transposed + grouped case: groups = 2 with transposed = true lowers
// to linalg.conv_2d_ngchw_gfchw on the padded input, then collapses the group
// dimension and casts back to the static output shape.
// CHECK-LABEL: func.func @transposedGroupedConvolution2D(
// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
// CHECK-NEXT: %[[COLLAPSED_TENSOR:.*]] = tensor.collapse_shape
// CHECK-SAME: tensor<1x2x2x11x15xf32> into tensor<1x4x11x15xf32>
// CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[COLLAPSED_TENSOR:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %true = torch.constant.bool true
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_2_3_3_torch.float32> : tensor<2x2x3x3xf32>) : !torch.vtensor<[2,2,3,3],f32>
  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  // stride = [2, 2], padding = [0, 0], dilation = [1, 1], output_padding = [0, 0]
  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  // transposed = true, groups = 2
  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int2 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
  return %6 : !torch.vtensor<[1,4,10,14],f32>
}