diff --git a/pysr/export_torch.py b/pysr/export_torch.py index eb3ccd8aa..3dcbafecc 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -92,7 +92,10 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._sympy_func = expr.func - if issubclass(expr.func, sympy.Float): + if ( + issubclass(expr.func, sympy.Float) + or expr.func is sympy.core.numbers.One + ): self._value = torch.nn.Parameter(torch.tensor(float(expr))) self._torch_func = lambda: self._value self._args = () diff --git a/pysr/test/test_torch.py b/pysr/test/test_torch.py index 256d21f86..58c6179df 100644 --- a/pysr/test/test_torch.py +++ b/pysr/test/test_torch.py @@ -29,6 +29,16 @@ def test_sympy2torch(self): np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy())) ) + def test_sympy2torch_number_symbol(self): + x, y, z = sympy.symbols("x y z") + expr = sympy.sin(sympy.sign(-0.041662704)) + + X = self.torch.tensor(np.random.randn(1000, 3)) + true = self.torch.sin(self.torch.tensor(-1)) + torch_module = sympy2torch(expr, [x, y, z]) + torch_out = torch_module(X) + self.assertTrue(np.isclose(torch_out.detach().numpy(), true.detach().numpy())) + def test_pipeline_pandas(self): X = pd.DataFrame(np.random.randn(100, 10)) y = np.ones(X.shape[0])