Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):

self._sympy_func = expr.func

if issubclass(expr.func, sympy.Float):
if (
issubclass(expr.func, sympy.Float)
or expr.func is sympy.core.numbers.One
):
self._value = torch.nn.Parameter(torch.tensor(float(expr)))
self._torch_func = lambda: self._value
self._args = ()
Expand Down
10 changes: 10 additions & 0 deletions pysr/test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def test_sympy2torch(self):
np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
)

def test_sympy2torch_number_symbol(self):
x, y, z = sympy.symbols("x y z")
expr = sympy.sin(sympy.sign(-0.041662704))

X = self.torch.tensor(np.random.randn(1000, 3))
true = self.torch.sin(self.torch.tensor(-1))
torch_module = sympy2torch(expr, [x, y, z])
torch_out = torch_module(X)
self.assertTrue(np.isclose(torch_out.detach().numpy(), true.detach().numpy()))

def test_pipeline_pandas(self):
X = pd.DataFrame(np.random.randn(100, 10))
y = np.ones(X.shape[0])
Expand Down
Loading