From 87e1999ecf17c410775bc8feec63ae8b758c02a5 Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Sat, 10 Jan 2026 13:32:27 +0800
Subject: [PATCH] [Relax] Fix HardSigmoid returns 1.0 for NaN input

---
 .../tvm/relax/frontend/onnx/onnx_frontend.py | 16 +++++++++++-
 tests/python/relax/test_frontend_onnx.py     | 25 +++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2212fa6c68ea..d67c43229a97 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3255,7 +3255,21 @@ def _impl_v1(cls, bb, inputs, attr, params):
         alpha = relax.const(alpha, dtype=dtype)
         beta = float(attr.get("beta", 0.5))
         beta = relax.const(beta, dtype=dtype)
-        return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
+
+        # Detect NaN values BEFORE applying any operations that might change them
+        if isinstance(x, relax.Constant):
+            x_data = x.data.numpy()
+            is_nan_data = _np.isnan(x_data)
+            is_nan = relax.const(is_nan_data, dtype="bool")
+        else:
+            is_nan = relax.op.not_equal(x, x)
+
+        # Apply the standard HardSigmoid computation
+        clipped = relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
+
+        # Preserve NaN values: where x is NaN, return NaN instead of clipped value
+        nan_val = relax.const(_np.nan, dtype=dtype)
+        return relax.op.where(is_nan, nan_val, clipped)
 
 
 class HardSwish(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index eb4c557e754c..73129f2e7b5e 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1086,6 +1086,31 @@ def test_hardsigmoid():
     verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6})
 
 
+def test_hardsigmoid_nan():
+    """Test that HardSigmoid preserves NaN values in output."""
+    test_node = helper.make_node("HardSigmoid", ["x"], ["y"])
+    graph = helper.make_graph(
+        [test_node],
+        "hardsigmoid_nan_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
+    )
+
+    model = helper.make_model(graph, producer_name="hardsigmoid_nan_test")
+
+    # Create input with NaN values
+    input_data = np.array(
+        [
+            [np.nan, 0.5, -0.5, 1.0],
+            [0.0, np.nan, 2.0, -2.0],
+            [0.3, 0.7, np.nan, np.nan],
+        ],
+        dtype=np.float32,
+    )
+
+    check_correctness(model, inputs={"x": input_data})
+
+
 def test_shrink():
     verify_unary("Shrink", [32, 32])
     verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})