Skip to content

Commit 3851e4e

Browse files
committed
Add conv2d e2e test from convnext model
1 parent 1b8d7e0 commit 3851e4e

File tree

1 file changed

+32
-0
lines changed
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

1 file changed

+32
-0
lines changed

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,38 @@ def Convolution2DStaticModule_basic(module, tu: TestUtils):
256256
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
257257

258258

259+
class Convolution2DNextStaticModule(torch.nn.Module):
260+
def __init__(self):
261+
super().__init__()
262+
263+
@export
264+
@annotate_args(
265+
[
266+
None,
267+
([1, 80, 72, 72], torch.float32, True),
268+
([80, 1, 7, 7], torch.float32, True),
269+
([80], torch.float32, True),
270+
]
271+
)
272+
def forward(self, inputVec, weight, bias):
273+
return torch.ops.aten.convolution(
274+
inputVec,
275+
weight,
276+
bias=bias,
277+
stride=[1, 1],
278+
padding=[3, 3],
279+
dilation=[1, 1],
280+
transposed=False,
281+
output_padding=[0, 0],
282+
groups=80,
283+
)
284+
285+
286+
@register_test_case(module_factory=lambda: Convolution2DNextStaticModule())
287+
def Convolution2DNextStaticModule_basic(module, tu: TestUtils):
288+
module.forward(tu.rand(1, 80, 72, 72), tu.rand(80, 1, 7, 7), tu.rand(80))
289+
290+
259291
class Convolution2DStridedModule(torch.nn.Module):
260292
def __init__(self):
261293
super().__init__()

0 commit comments

Comments
 (0)