From 905e87746a73594fe44df793cbe40cc445ca2ea9 Mon Sep 17 00:00:00 2001 From: tbuckworth <55180288+tbuckworth@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:24:37 +0100 Subject: [PATCH 1/6] converting sympy.NumberSymbol to torch.tensor in export_torch.py attempting to address #656 --- pysr/export_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index be3d6a163..c421f859f 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -116,7 +116,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._args = () elif issubclass(expr.func, sympy.NumberSymbol): # Can get here from exp(1) or exact pi - self._value = float(expr) + self._value = torch.tensor(float(expr)) self._torch_func = lambda: self._value self._args = () elif issubclass(expr.func, sympy.Symbol): From 39f9295f6bb47dd8006520239fe0851abc038e16 Mon Sep 17 00:00:00 2001 From: tbuckworth <55180288+tbuckworth@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:26:08 +0100 Subject: [PATCH 2/6] added unit test for export_torch.py NumberSymbol --- pysr/test/test_torch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pysr/test/test_torch.py b/pysr/test/test_torch.py index 8b26f5ca6..9aae53f2b 100644 --- a/pysr/test/test_torch.py +++ b/pysr/test/test_torch.py @@ -28,6 +28,16 @@ def test_sympy2torch(self): np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy())) ) + def test_sympy2torch_number_symbol(self): + x, y, z = sympy.symbols("x y z") + expr = sin(sign(-0.041662704)) + + X = self.torch.tensor(np.random.randn(1000, 3)) + true = self.torch.sin(self.torch.tensor(-1)) + torch_module = sympy2torch(expr, [x, y, z]) + torch_out = torch_module(X) + self.assertTrue(np.isclose(torch_out.detach().numpy(), true.detach().numpy())) + def test_pipeline_pandas(self): X = pd.DataFrame(np.random.randn(100, 10)) y = np.ones(X.shape[0]) From 7dcf78683cc0dc7db165d19ebf48bfe796d64460 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:27:09 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pysr/test/test_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysr/test/test_torch.py b/pysr/test/test_torch.py index 9aae53f2b..cb692a73c 100644 --- a/pysr/test/test_torch.py +++ b/pysr/test/test_torch.py @@ -31,7 +31,7 @@ def test_sympy2torch(self): def test_sympy2torch_number_symbol(self): x, y, z = sympy.symbols("x y z") expr = sin(sign(-0.041662704)) - + X = self.torch.tensor(np.random.randn(1000, 3)) true = self.torch.sin(self.torch.tensor(-1)) torch_module = sympy2torch(expr, [x, y, z]) From 8e56560456d1e0c64d2a3e08dac54317b603dba6 Mon Sep 17 00:00:00 2001 From: tbuckworth <55180288+tbuckworth@users.noreply.github.com> Date: Mon, 30 Sep 2024 09:50:41 +0100 Subject: [PATCH 4/6] fixed minor import bug test_torch.py --- pysr/test/test_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysr/test/test_torch.py b/pysr/test/test_torch.py index cb692a73c..ebdb4af55 100644 --- a/pysr/test/test_torch.py +++ b/pysr/test/test_torch.py @@ -30,7 +30,7 @@ def test_sympy2torch(self): def test_sympy2torch_number_symbol(self): x, y, z = sympy.symbols("x y z") - expr = sin(sign(-0.041662704)) + expr = sympy.sin(sympy.sign(-0.041662704)) X = self.torch.tensor(np.random.randn(1000, 3)) true = self.torch.sin(self.torch.tensor(-1)) From c52749848c46edfffa318aa45947ae0ef8a24be3 Mon Sep 17 00:00:00 2001 From: tbuckworth Date: Sun, 6 Oct 2024 10:57:12 +0100 Subject: [PATCH 5/6] added condition for sympy.core.numbers.One --- pysr/export_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index c421f859f..e2f20a0b9 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -89,7 +89,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._sympy_func = expr.func - if issubclass(expr.func, sympy.Float): + if issubclass(expr.func, sympy.Float) or expr.func is sympy.core.numbers.One: self._value = torch.nn.Parameter(torch.tensor(float(expr))) self._torch_func = lambda: self._value self._args = () @@ -116,7 +116,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._args = () elif issubclass(expr.func, sympy.NumberSymbol): # Can get here from exp(1) or exact pi - self._value = torch.tensor(float(expr)) + self._value = float(expr) self._torch_func = lambda: self._value self._args = () elif issubclass(expr.func, sympy.Symbol): From e4bc71ebf5e3a9761d10475c1fccc02c039c73a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Oct 2024 09:57:25 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pysr/export_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index e2f20a0b9..8d703ed4c 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -89,7 +89,10 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._sympy_func = expr.func - if issubclass(expr.func, sympy.Float) or expr.func is sympy.core.numbers.One: + if ( + issubclass(expr.func, sympy.Float) + or expr.func is sympy.core.numbers.One + ): self._value = torch.nn.Parameter(torch.tensor(float(expr))) self._torch_func = lambda: self._value self._args = ()