Skip to content

Commit

Permalink
[Bugfix][Relay] Fix AdaptiveAvgPool2d about wrong dtype prasing (apac…
Browse files Browse the repository at this point in the history
…he#14837)

* fix adaptive_avg_pool about wrong dtype

* add test case

* Update test_forward.py

* Update test_forward.py
  • Loading branch information
jikechao authored May 13, 2023
1 parent 71d3262 commit 5eff44a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,10 @@ def hard_swish(self, inputs, input_types):
def adaptive_avg_pool(self, op, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
for i, item in enumerate(output_size):
if isinstance(item, tvm.relay.expr.Constant):
# convert Constant to int
output_size[i] = item.data.numpy()[()]

def func(x):
return op(x, output_size=output_size)
Expand Down
3 changes: 3 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,9 @@ def test_forward_adaptive_avgpool():
verify_model(torch.nn.AdaptiveAvgPool1d([1]).eval(), input_data=input_data)
verify_model(torch.nn.AdaptiveAvgPool1d([5]).eval(), input_data=input_data)

input_data = torch.rand([1, 3, 5, 6]).float()
verify_model(torch.nn.AdaptiveAvgPool2d([3, None]).eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_adaptive_maxpool():
Expand Down

0 comments on commit 5eff44a

Please sign in to comment.