Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.Rational):
# This is some fraction fixed in the operator.
self._value = float(expr)
# Includes Integer, since Integer is a subclass of Rational
self.register_buffer("_value", torch.tensor(float(expr)))
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.UnevaluatedExpr):
Expand All @@ -111,15 +111,9 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
self.register_buffer("_value", torch.tensor(float(expr.args[0])))
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.Integer):
# Can get here if expr is one of the Integer special cases,
# e.g. NegativeOne
self._value = int(expr)
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.NumberSymbol):
# Can get here from exp(1) or exact pi
self._value = float(expr)
self.register_buffer("_value", torch.tensor(float(expr)))
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.Symbol):
Expand Down
26 changes: 26 additions & 0 deletions pysr/test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,32 @@ def test_issue_656(self):
decimal=3,
)

def test_constant_arguments(self):
# Test that functions with constant arguments work correctly
# Regression test for https://github.com/MilesCranmer/PySR/issues/656
test_cases = [
(pysr.export_sympy.pysr2sympy("sqrt(2)"), np.sqrt(2)),
(sympy.exp(2), np.exp(2)),
(sympy.log(4), np.log(4)),
(sympy.sin(1), np.sin(1)),
]

for expr, expected in test_cases:
m = pysr.export_torch.sympy2torch(expr, [])
result = m(self.torch.randn(10, 1))
np.testing.assert_almost_equal(result.item(), expected, decimal=3)

# Test with variables: sqrt(2) * x
x = sympy.symbols("x")
expr = sympy.sqrt(2) * x
m = pysr.export_torch.sympy2torch(expr, [x])
X = np.random.randn(10, 1)
np.testing.assert_almost_equal(
m(self.torch.tensor(X)).detach().numpy().flatten(),
np.sqrt(2) * X[:, 0],
decimal=3,
)

def test_feature_selection_custom_operators(self):
rstate = np.random.RandomState(0)
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
Expand Down
Loading