From e280034231acb03b146e6a0333e02d3fc38acebc Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 12 Oct 2025 18:29:16 +0100 Subject: [PATCH 1/2] fix: torch export with constant arguments Co-authored-by: tbuckworth <55180288+tbuckworth@users.noreply.github.com> --- pysr/export_torch.py | 6 +++--- pysr/test/test_torch.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index eb3ccd8aa..63b9a02d5 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -98,7 +98,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._args = () elif issubclass(expr.func, sympy.Rational): # This is some fraction fixed in the operator. - self._value = float(expr) + self.register_buffer("_value", torch.tensor(float(expr))) self._torch_func = lambda: self._value self._args = () elif issubclass(expr.func, sympy.UnevaluatedExpr): @@ -114,12 +114,12 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): elif issubclass(expr.func, sympy.Integer): # Can get here if expr is one of the Integer special cases, # e.g. NegativeOne - self._value = int(expr) + self.register_buffer("_value", torch.tensor(int(expr))) self._torch_func = lambda: self._value self._args = () elif issubclass(expr.func, sympy.NumberSymbol): # Can get here from exp(1) or exact pi - self._value = float(expr) + self.register_buffer("_value", torch.tensor(float(expr))) self._torch_func = lambda: self._value self._args = () elif issubclass(expr.func, sympy.Symbol): diff --git a/pysr/test/test_torch.py b/pysr/test/test_torch.py index 256d21f86..e66e7d960 100644 --- a/pysr/test/test_torch.py +++ b/pysr/test/test_torch.py @@ -192,6 +192,32 @@ def test_issue_656(self): decimal=3, ) + def test_constant_arguments(self): + # Test that functions with constant arguments work correctly + # Regression test for https://github.com/MilesCranmer/PySR/issues/656 + test_cases = [ + (pysr.export_sympy.pysr2sympy("sqrt(2)"), np.sqrt(2)), + (sympy.exp(2), np.exp(2)), + (sympy.log(4), np.log(4)), + (sympy.sin(1), np.sin(1)), + ] + + for expr, expected in test_cases: + m = pysr.export_torch.sympy2torch(expr, []) + result = m(self.torch.randn(10, 1)) + np.testing.assert_almost_equal(result.item(), expected, decimal=3) + + # Test with variables: sqrt(2) * x + x = sympy.symbols("x") + expr = sympy.sqrt(2) * x + m = pysr.export_torch.sympy2torch(expr, [x]) + X = np.random.randn(10, 1) + np.testing.assert_almost_equal( + m(self.torch.tensor(X)).detach().numpy().flatten(), + np.sqrt(2) * X[:, 0], + decimal=3, + ) + def test_feature_selection_custom_operators(self): rstate = np.random.RandomState(0) X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)}) From fc55721b5c568734ea236e0ba757d79a65eae2ed Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 12 Oct 2025 21:57:55 +0100 Subject: [PATCH 2/2] refactor: remove dead branch --- pysr/export_torch.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 63b9a02d5..ef9aeaeb8 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -97,7 +97,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._torch_func = lambda: self._value self._args = () elif issubclass(expr.func, sympy.Rational): - # This is some fraction fixed in the operator. + # Includes Integer, since Integer is a subclass of Rational self.register_buffer("_value", torch.tensor(float(expr))) self._torch_func = lambda: self._value self._args = () @@ -111,12 +111,6 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self.register_buffer("_value", torch.tensor(float(expr.args[0]))) self._torch_func = lambda: self._value self._args = () - elif issubclass(expr.func, sympy.Integer): - # Can get here if expr is one of the Integer special cases, - # e.g. NegativeOne - self.register_buffer("_value", torch.tensor(int(expr))) - self._torch_func = lambda: self._value - self._args = () elif issubclass(expr.func, sympy.NumberSymbol): # Can get here from exp(1) or exact pi self.register_buffer("_value", torch.tensor(float(expr)))