From 5eff44aea838feacf567a0544c268220ddcbcbbe Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Sun, 14 May 2023 05:20:21 +0800 Subject: [PATCH] [Bugfix][Relay] Fix AdaptiveAvgPool2d about wrong dtype prasing (#14837) * fix adaptive_avg_pool about wrong dtype * add test case * Update test_forward.py * Update test_forward.py --- python/tvm/relay/frontend/pytorch.py | 4 ++++ tests/python/frontend/pytorch/test_forward.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1f23fe4a2c83..1ef8b6faee62 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1107,6 +1107,10 @@ def hard_swish(self, inputs, input_types): def adaptive_avg_pool(self, op, inputs, input_types): data = inputs[0] output_size = inputs[1] + for i, item in enumerate(output_size): + if isinstance(item, tvm.relay.expr.Constant): + # convert Constant to int + output_size[i] = item.data.numpy()[()] def func(x): return op(x, output_size=output_size) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index b2d0bf3a2edf..ffa37af3315a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -873,6 +873,9 @@ def test_forward_adaptive_avgpool(): verify_model(torch.nn.AdaptiveAvgPool1d([1]).eval(), input_data=input_data) verify_model(torch.nn.AdaptiveAvgPool1d([5]).eval(), input_data=input_data) + input_data = torch.rand([1, 3, 5, 6]).float() + verify_model(torch.nn.AdaptiveAvgPool2d([3, None]).eval(), input_data=input_data) + @tvm.testing.uses_gpu def test_forward_adaptive_maxpool():