From 483b87de6215338d34e4464fa2f6bd157d7de94c Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Sat, 13 May 2023 07:55:27 +0800 Subject: [PATCH] [Bugfix][Relay] Fix softplus about the wrong calculation formula in Relay PyTorch frontend (#14821) * fix softplus operator * add test cases * Update pytorch.py * Update pytorch.py --- python/tvm/relay/frontend/pytorch.py | 5 ++++- tests/python/frontend/pytorch/test_forward.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 078178a7e095..1f23fe4a2c83 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1685,7 +1685,10 @@ def func(x): def softplus(self, inputs, input_types): dtype = input_types[0] beta = _expr.const(float(inputs[1]), dtype=dtype) - return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta + threshold = int(inputs[2]) if inputs[2] else 20 + threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold)) + softplus_value = _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta + return _op.where(_op.greater(inputs[0] * beta, threshold_), inputs[0], softplus_value) def make_avg_pool(self, dim): def avg_pool(inputs, input_types): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index fcaf7b7847bd..b2d0bf3a2edf 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -835,6 +835,9 @@ def test_forward_softplus(): verify_model(torch.nn.Softplus().eval(), input_data=input_data) verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data) verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data) + verify_model(torch.nn.Softplus(beta=5, threshold=1).eval(), input_data=input_data) + verify_model(torch.nn.Softplus(beta=1, threshold=2).eval(), input_data=input_data) + verify_model(torch.nn.Softplus(beta=1, threshold=-1).eval(), input_data=input_data) @tvm.testing.uses_gpu